Add support for microcontroller-scale audio speech models

PiperOrigin-RevId: 204204001
Pete Warden 2018-07-11 15:48:43 -07:00 committed by TensorFlower Gardener
parent f2fa55c8d2
commit 72e0e3d383
9 changed files with 562 additions and 189 deletions

tensorflow/examples/speech_commands/BUILD

@ -56,6 +56,7 @@ tf_py_test(
srcs = ["input_data_test.py"],
additional_deps = [
":input_data",
":models",
"//tensorflow/python:client_testlib",
],
)

tensorflow/examples/speech_commands/freeze.py

@ -54,7 +54,7 @@ FLAGS = None
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms, window_size_ms, window_stride_ms,
dct_coefficient_count, model_architecture):
feature_bin_count, model_architecture, preprocess):
"""Creates an audio model with the nodes needed for inference.
Uses the supplied arguments to create a model, and inserts the input and
@ -67,14 +67,19 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms: How often to run recognition. Useful for models with cache.
window_size_ms: Time slice duration to estimate frequencies from.
window_stride_ms: How far apart time slices should be.
dct_coefficient_count: Number of frequency bands to analyze.
feature_bin_count: Number of frequency bands to analyze.
model_architecture: Name of the kind of model to generate.
preprocess: How the spectrogram is processed to produce features, for
example 'mfcc' or 'average'.
Raises:
Exception: If the preprocessing mode isn't recognized.
"""
words_list = input_data.prepare_words_list(wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
window_stride_ms, dct_coefficient_count)
window_stride_ms, feature_bin_count, preprocess)
runtime_settings = {'clip_stride_ms': clip_stride_ms}
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
@ -88,15 +93,25 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
fingerprint_input = contrib_audio.mfcc(
spectrogram,
decoded_sample_data.sample_rate,
dct_coefficient_count=dct_coefficient_count)
fingerprint_frequency_size = model_settings['dct_coefficient_count']
fingerprint_time_size = model_settings['spectrogram_length']
reshaped_input = tf.reshape(fingerprint_input, [
-1, fingerprint_time_size * fingerprint_frequency_size
])
if preprocess == 'average':
fingerprint_input = tf.nn.pool(
tf.expand_dims(spectrogram, -1),
window_shape=[1, model_settings['average_window_width']],
strides=[1, model_settings['average_window_width']],
pooling_type='AVG',
padding='SAME')
elif preprocess == 'mfcc':
fingerprint_input = contrib_audio.mfcc(
spectrogram,
sample_rate,
dct_coefficient_count=model_settings['fingerprint_width'])
else:
raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
' "average")' % (preprocess))
fingerprint_size = model_settings['fingerprint_size']
reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])
logits = models.create_model(
reshaped_input, model_settings, model_architecture, is_training=False,
@ -110,10 +125,12 @@ def main(_):
# Create the model and load its weights.
sess = tf.InteractiveSession()
create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.dct_coefficient_count, FLAGS.model_architecture)
create_inference_graph(
FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
if FLAGS.quantize:
tf.contrib.quantize.create_training_graph(quant_delay=0)
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
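A hedged sketch of the variable-freezing step the comment above refers to, which the diff truncates (assuming the usual TF 1.x graph_util pattern; 'labels_softmax' is the output name checked in freeze_test.py below):

from tensorflow.python.framework import graph_util

# Replace every variable in the session's graph with a constant holding its
# current value, keeping only ops needed to compute the named output.
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), ['labels_softmax'])
with tf.gfile.FastGFile(FLAGS.output_file, 'wb') as f:
  f.write(frozen_graph_def.SerializeToString())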
@ -155,10 +172,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
'--dct_coefficient_count',
'--feature_bin_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
help='How many bins to use for the MFCC fingerprint',
)
parser.add_argument(
'--start_checkpoint',
type=str,
@ -176,5 +194,15 @@ if __name__ == '__main__':
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--output_file', type=str, help='Where to save the frozen graph.')
parser.add_argument(
'--quantize',
type=bool,
default=False,
help='Whether to train the model for eight-bit deployment')
parser.add_argument(
'--preprocess',
type=str,
default='mfcc',
help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

tensorflow/examples/speech_commands/freeze_test.py

@ -24,14 +24,62 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
def testCreateInferenceGraph(self):
def testCreateInferenceGraphWithMfcc(self):
with self.test_session() as sess:
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
40, 'conv')
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
clip_duration_ms=1000.0,
clip_stride_ms=30.0,
window_size_ms=30.0,
window_stride_ms=10.0,
feature_bin_count=40,
model_architecture='conv',
preprocess='mfcc')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
ops = [node.op for node in sess.graph_def.node]
self.assertEqual(1, ops.count('Mfcc'))
def testCreateInferenceGraphWithoutMfcc(self):
with self.test_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
clip_duration_ms=1000.0,
clip_stride_ms=30.0,
window_size_ms=30.0,
window_stride_ms=10.0,
feature_bin_count=40,
model_architecture='conv',
preprocess='average')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
ops = [node.op for node in sess.graph_def.node]
self.assertEqual(0, ops.count('Mfcc'))
def testFeatureBinCount(self):
with self.test_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
clip_duration_ms=1000.0,
clip_stride_ms=30.0,
window_size_ms=30.0,
window_stride_ms=10.0,
feature_bin_count=80,
model_architecture='conv',
preprocess='average')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
ops = [node.op for node in sess.graph_def.node]
self.assertEqual(0, ops.count('Mfcc'))
if __name__ == '__main__':

tensorflow/examples/speech_commands/generate_streaming_test_wav.py

@ -87,11 +87,12 @@ def main(_):
words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count,
'mfcc')
audio_processor = input_data.AudioProcessor(
'', FLAGS.data_dir, FLAGS.silence_percentage, 10,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
FLAGS.testing_percentage, model_settings)
FLAGS.testing_percentage, model_settings, FLAGS.data_dir)
output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)
@ -242,10 +243,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
'--dct_coefficient_count',
'--feature_bin_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
help='How many bins to use for the MFCC fingerprint',
)
parser.add_argument(
'--wanted_words',
type=str,

tensorflow/examples/speech_commands/input_data.py

@ -153,14 +153,14 @@ class AudioProcessor(object):
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
model_settings):
model_settings, summaries_dir):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
self.prepare_processing_graph(model_settings)
self.prepare_processing_graph(model_settings, summaries_dir)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
"""Download and extract data set tar file.
@ -325,7 +325,7 @@ class AudioProcessor(object):
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
def prepare_processing_graph(self, model_settings):
def prepare_processing_graph(self, model_settings, summaries_dir):
"""Builds a TensorFlow graph to apply the input distortions.
Creates a graph that loads a WAVE file, decodes it, scales the volume,
@ -341,48 +341,88 @@ class AudioProcessor(object):
- time_shift_offset_placeholder_: How much to move the clip in time.
- background_data_placeholder_: PCM sample data for background noise.
- background_volume_placeholder_: Loudness of mixed-in background.
- mfcc_: Output 2D fingerprint of processed audio.
- output_: Output 2D fingerprint of processed audio.
Args:
model_settings: Information about the current model being trained.
summaries_dir: Path to save training summary information to.
Raises:
ValueError: If the preprocessing mode isn't recognized.
"""
desired_samples = model_settings['desired_samples']
self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
wav_decoder = contrib_audio.decode_wav(
wav_loader, desired_channels=1, desired_samples=desired_samples)
# Allow the audio sample's volume to be adjusted.
self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
scaled_foreground = tf.multiply(wav_decoder.audio,
self.foreground_volume_placeholder_)
# Shift the sample's start position, and pad any gaps with zeros.
self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
padded_foreground = tf.pad(
scaled_foreground,
self.time_shift_padding_placeholder_,
mode='CONSTANT')
sliced_foreground = tf.slice(padded_foreground,
self.time_shift_offset_placeholder_,
[desired_samples, -1])
# Mix in background noise.
self.background_data_placeholder_ = tf.placeholder(tf.float32,
[desired_samples, 1])
self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
background_mul = tf.multiply(self.background_data_placeholder_,
self.background_volume_placeholder_)
background_add = tf.add(background_mul, sliced_foreground)
background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
# Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
spectrogram = contrib_audio.audio_spectrogram(
background_clamp,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
self.mfcc_ = contrib_audio.mfcc(
spectrogram,
wav_decoder.sample_rate,
dct_coefficient_count=model_settings['dct_coefficient_count'])
with tf.get_default_graph().name_scope('data'):
desired_samples = model_settings['desired_samples']
self.wav_filename_placeholder_ = tf.placeholder(
tf.string, [], name='wav_filename')
wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
wav_decoder = contrib_audio.decode_wav(
wav_loader, desired_channels=1, desired_samples=desired_samples)
# Allow the audio sample's volume to be adjusted.
self.foreground_volume_placeholder_ = tf.placeholder(
tf.float32, [], name='foreground_volume')
scaled_foreground = tf.multiply(wav_decoder.audio,
self.foreground_volume_placeholder_)
# Shift the sample's start position, and pad any gaps with zeros.
self.time_shift_padding_placeholder_ = tf.placeholder(
tf.int32, [2, 2], name='time_shift_padding')
self.time_shift_offset_placeholder_ = tf.placeholder(
tf.int32, [2], name='time_shift_offset')
padded_foreground = tf.pad(
scaled_foreground,
self.time_shift_padding_placeholder_,
mode='CONSTANT')
sliced_foreground = tf.slice(padded_foreground,
self.time_shift_offset_placeholder_,
[desired_samples, -1])
# Mix in background noise.
self.background_data_placeholder_ = tf.placeholder(
tf.float32, [desired_samples, 1], name='background_data')
self.background_volume_placeholder_ = tf.placeholder(
tf.float32, [], name='background_volume')
background_mul = tf.multiply(self.background_data_placeholder_,
self.background_volume_placeholder_)
background_add = tf.add(background_mul, sliced_foreground)
background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
# Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
spectrogram = contrib_audio.audio_spectrogram(
background_clamp,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
tf.summary.image(
'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
# The number of buckets in each FFT row in the spectrogram will depend on
# how many input samples there are in each window. This can be quite
large, with a 160 sample window producing 129 buckets for example. We
# don't need this level of detail for classification, so we often want to
# shrink them down to produce a smaller result. That's what this section
# implements. One method is to use average pooling to merge adjacent
# buckets, but a more sophisticated approach is to apply the MFCC
# algorithm to shrink the representation.
if model_settings['preprocess'] == 'average':
self.output_ = tf.nn.pool(
tf.expand_dims(spectrogram, -1),
window_shape=[1, model_settings['average_window_width']],
strides=[1, model_settings['average_window_width']],
pooling_type='AVG',
padding='SAME')
tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
elif model_settings['preprocess'] == 'mfcc':
self.output_ = contrib_audio.mfcc(
spectrogram,
wav_decoder.sample_rate,
dct_coefficient_count=model_settings['fingerprint_width'])
tf.summary.image(
'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
else:
raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
' "average")' % (model_settings['preprocess']))
# Merge all the summaries and write them out to /tmp/retrain_logs (by
# default)
self.merged_summaries_ = tf.summary.merge_all(scope='data')
self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
tf.get_default_graph())
def set_size(self, mode):
"""Calculates the number of samples in the dataset partition.
@ -418,6 +458,9 @@ class AudioProcessor(object):
Returns:
List of sample data for the transformed samples, and list of label indexes
Raises:
ValueError: If background samples are too short.
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
@ -460,6 +503,11 @@ class AudioProcessor(object):
if use_background or sample['label'] == SILENCE_LABEL:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
if len(background_samples) <= model_settings['desired_samples']:
raise ValueError(
'Background sample is too short! Need more than %d'
' samples but only %d were found' %
(model_settings['desired_samples'], len(background_samples)))
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
@ -482,7 +530,10 @@ class AudioProcessor(object):
else:
input_dict[self.foreground_volume_placeholder_] = 1
# Run the graph to produce the output audio.
data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
summary, data_tensor = sess.run(
[self.merged_summaries_, self.output_], feed_dict=input_dict)
self.summary_writer_.add_summary(summary)
data[i - offset, :] = data_tensor.flatten()
label_index = self.word_to_index[sample['label']]
labels[i - offset] = label_index
return data, labels
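The prepare_processing_graph changes above put every summary under a 'data' name scope so that tf.summary.merge_all(scope='data') collects only the input pipeline's summaries, leaving train.py free to merge its own 'eval' scope separately. A minimal sketch of that scoped-merge behavior, assuming the TF 1.x summary API this commit already uses:

import tensorflow as tf

with tf.Graph().as_default():
  with tf.name_scope('data'):
    tf.summary.scalar('inside', tf.constant(1.0))  # summary named 'data/inside'
  tf.summary.scalar('outside', tf.constant(2.0))
  data_only = tf.summary.merge_all(scope='data')  # merges only 'data/inside'
  everything = tf.summary.merge_all()             # merges both summaries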

tensorflow/examples/speech_commands/input_data_test.py

@ -25,6 +25,7 @@ import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import input_data
from tensorflow.examples.speech_commands import models
from tensorflow.python.platform import test
@ -32,7 +33,7 @@ class InputDataTest(test.TestCase):
def _getWavData(self):
with self.test_session() as sess:
sample_data = tf.zeros([1000, 2])
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
return wav_data
@ -57,9 +58,31 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
"fingerprint_width": 40,
"preprocess": "mfcc",
}
def _runGetDataTest(self, preprocess, window_length_ms):
tmp_dir = self.get_temp_dir()
wav_dir = os.path.join(tmp_dir, "wavs")
os.mkdir(wav_dir)
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
background_dir = os.path.join(wav_dir, "_background_noise_")
os.mkdir(background_dir)
wav_data = self._getWavData()
for i in range(10):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
model_settings = models.prepare_model_settings(
4, 16000, 1000, window_length_ms, 20, 40, preprocess)
with self.test_session() as sess:
audio_processor = input_data.AudioProcessor(
"", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_data(
10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
self.assertEqual(10, len(result_data))
self.assertEqual(10, len(result_labels))
def testPrepareWordsList(self):
words_list = ["a", "b"]
self.assertGreater(
@ -76,8 +99,9 @@ class InputDataTest(test.TestCase):
def testPrepareDataIndex(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
10, 10, self._model_settings())
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
["a", "b"], 10, 10,
self._model_settings(), tmp_dir)
self.assertLess(0, audio_processor.set_size("training"))
self.assertTrue("training" in audio_processor.data_index)
self.assertTrue("validation" in audio_processor.data_index)
@ -90,7 +114,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
self._model_settings())
self._model_settings(), tmp_dir)
self.assertTrue("No .wavs found" in str(e.exception))
def testPrepareDataIndexMissing(self):
@ -98,7 +122,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
10, self._model_settings())
10, self._model_settings(), tmp_dir)
self.assertTrue("Expected to find" in str(e.exception))
def testPrepareBackgroundData(self):
@ -110,8 +134,9 @@ class InputDataTest(test.TestCase):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
10, 10, self._model_settings())
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
["a", "b"], 10, 10,
self._model_settings(), tmp_dir)
self.assertEqual(10, len(audio_processor.background_data))
def testLoadWavFile(self):
@ -148,44 +173,27 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
"fingerprint_width": 40,
"preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
10, 10, model_settings, tmp_dir)
self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
self.assertIsNotNone(audio_processor.background_data_placeholder_)
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
self.assertIsNotNone(audio_processor.mfcc_)
self.assertIsNotNone(audio_processor.output_)
def testGetData(self):
tmp_dir = self.get_temp_dir()
wav_dir = os.path.join(tmp_dir, "wavs")
os.mkdir(wav_dir)
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
background_dir = os.path.join(wav_dir, "_background_noise_")
os.mkdir(background_dir)
wav_data = self._getWavData()
for i in range(10):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
model_settings = {
"desired_samples": 160,
"fingerprint_size": 40,
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
with self.test_session() as sess:
result_data, result_labels = audio_processor.get_data(
10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
self.assertEqual(10, len(result_data))
self.assertEqual(10, len(result_labels))
def testGetDataAverage(self):
self._runGetDataTest("average", 10)
def testGetDataAverageLongWindow(self):
self._runGetDataTest("average", 30)
def testGetDataMfcc(self):
self._runGetDataTest("mfcc", 30)
def testGetUnprocessedData(self):
tmp_dir = self.get_temp_dir()
@ -198,10 +206,11 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
"fingerprint_width": 40,
"preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_unprocessed_data(
10, model_settings, "training")
self.assertEqual(10, len(result_data))

tensorflow/examples/speech_commands/models.py

@ -24,9 +24,21 @@ import math
import tensorflow as tf
def _next_power_of_two(x):
"""Calculates the smallest enclosing power of two for an input.
Args:
x: Positive float or integer number.
Returns:
Next largest power of two integer.
"""
return 1 if x == 0 else 2**(int(x) - 1).bit_length()
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
dct_coefficient_count):
window_size_ms, window_stride_ms, feature_bin_count,
preprocess):
"""Calculates common settings needed for all models.
Args:
@ -35,10 +47,14 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
clip_duration_ms: Length of each audio clip to be analyzed.
window_size_ms: Duration of frequency analysis window.
window_stride_ms: How far to move in time between frequency windows.
dct_coefficient_count: Number of frequency bins to use for analysis.
feature_bin_count: Number of frequency bins to use for analysis.
preprocess: How the spectrogram is processed to produce features.
Returns:
Dictionary containing common settings.
Raises:
ValueError: If the preprocessing mode isn't recognized.
"""
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
@ -48,16 +64,28 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
fingerprint_size = dct_coefficient_count * spectrogram_length
if preprocess == 'average':
fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
elif preprocess == 'mfcc':
average_window_width = -1
fingerprint_width = feature_bin_count
else:
raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
' "average")' % (preprocess))
fingerprint_size = fingerprint_width * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
'dct_coefficient_count': dct_coefficient_count,
'fingerprint_width': fingerprint_width,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
'preprocess': preprocess,
'average_window_width': average_window_width,
}
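To make the 'average' branch above concrete, a worked example under assumed defaults (16kHz sample rate, 30ms window, feature_bin_count=40), using the _next_power_of_two helper defined earlier; note that the ceiling division means fingerprint_width can come out slightly larger than feature_bin_count:

import math

window_size_samples = 480  # 30ms at 16kHz
fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)  # 1 + 512/2 = 257
average_window_width = int(math.floor(fft_bin_count / 40))  # floor(257/40) = 6
fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))  # ceil(257/6) = 43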
@ -106,10 +134,14 @@ def create_model(fingerprint_input, model_settings, model_architecture,
elif model_architecture == 'low_latency_svdf':
return create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings)
elif model_architecture == 'tiny_conv':
return create_tiny_conv_model(fingerprint_input, model_settings,
is_training)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
' "low_latency_conv, or "low_latency_svdf"')
' "low_latency_conv, "low_latency_svdf",' +
' or "tiny_conv"')
def load_variables_from_checkpoint(sess, start_checkpoint):
@ -152,9 +184,12 @@ def create_single_fc_model(fingerprint_input, model_settings, is_training):
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
weights = tf.Variable(
tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
bias = tf.Variable(tf.zeros([label_count]))
weights = tf.get_variable(
name='weights',
initializer=tf.truncated_normal_initializer(stddev=0.001),
shape=[fingerprint_size, label_count])
bias = tf.get_variable(
name='bias', initializer=tf.zeros_initializer, shape=[label_count])
logits = tf.matmul(fingerprint_input, weights) + bias
if is_training:
return logits, dropout_prob
@ -212,18 +247,21 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 20
first_filter_count = 64
first_weights = tf.Variable(
tf.truncated_normal(
[first_filter_height, first_filter_width, 1, first_filter_count],
stddev=0.01))
first_bias = tf.Variable(tf.zeros([first_filter_count]))
first_weights = tf.get_variable(
name='first_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_filter_height, first_filter_width, 1, first_filter_count])
first_bias = tf.get_variable(
name='first_bias',
initializer=tf.zeros_initializer,
shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
@ -235,14 +273,17 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
second_filter_width = 4
second_filter_height = 10
second_filter_count = 64
second_weights = tf.Variable(
tf.truncated_normal(
[
second_filter_height, second_filter_width, first_filter_count,
second_filter_count
],
stddev=0.01))
second_bias = tf.Variable(tf.zeros([second_filter_count]))
second_weights = tf.get_variable(
name='second_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[
second_filter_height, second_filter_width, first_filter_count,
second_filter_count
])
second_bias = tf.get_variable(
name='second_bias',
initializer=tf.zeros_initializer,
shape=[second_filter_count])
second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
'SAME') + second_bias
second_relu = tf.nn.relu(second_conv)
@ -259,10 +300,14 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
flattened_second_conv = tf.reshape(second_dropout,
[-1, second_conv_element_count])
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_conv_element_count, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc_weights = tf.get_variable(
name='final_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_conv_element_count, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
initializer=tf.zeros_initializer,
shape=[label_count])
final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@ -318,7 +363,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
@ -327,11 +372,14 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
first_filter_count = 186
first_filter_stride_x = 1
first_filter_stride_y = 1
first_weights = tf.Variable(
tf.truncated_normal(
[first_filter_height, first_filter_width, 1, first_filter_count],
stddev=0.01))
first_bias = tf.Variable(tf.zeros([first_filter_count]))
first_weights = tf.get_variable(
name='first_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_filter_height, first_filter_width, 1, first_filter_count])
first_bias = tf.get_variable(
name='first_bias',
initializer=tf.zeros_initializer,
shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
1, first_filter_stride_y, first_filter_stride_x, 1
], 'VALID') + first_bias
@ -351,30 +399,42 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
flattened_first_conv = tf.reshape(first_dropout,
[-1, first_conv_element_count])
first_fc_output_channels = 128
first_fc_weights = tf.Variable(
tf.truncated_normal(
[first_conv_element_count, first_fc_output_channels], stddev=0.01))
first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
first_fc_weights = tf.get_variable(
name='first_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_conv_element_count, first_fc_output_channels])
first_fc_bias = tf.get_variable(
name='first_fc_bias',
initializer=tf.zeros_initializer,
shape=[first_fc_output_channels])
first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 128
second_fc_weights = tf.Variable(
tf.truncated_normal(
[first_fc_output_channels, second_fc_output_channels], stddev=0.01))
second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
second_fc_weights = tf.get_variable(
name='second_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_fc_output_channels, second_fc_output_channels])
second_fc_bias = tf.get_variable(
name='second_fc_bias',
initializer=tf.zeros_initializer,
shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_fc_output_channels, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc_weights = tf.get_variable(
name='final_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_fc_output_channels, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
initializer=tf.zeros_initializer,
shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@ -422,7 +482,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
The node is expected to produce a 2D Tensor of shape:
[batch, model_settings['dct_coefficient_count'] *
[batch, model_settings['fingerprint_width'] *
model_settings['spectrogram_length']]
with the features corresponding to the same time slot arranged contiguously,
and the oldest slot at index [:, 0], and newest at [:, -1].
@ -440,7 +500,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
# Validation.
@ -462,8 +522,11 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
num_filters = rank * num_units
# Create the runtime memory: [num_filters, batch, input_time_size]
batch = 1
memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
trainable=False, name='runtime-memory')
memory = tf.get_variable(
initializer=tf.zeros_initializer,
shape=[num_filters, batch, input_time_size],
trainable=False,
name='runtime-memory')
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
@ -483,8 +546,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
# Create the frequency filters.
weights_frequency = tf.Variable(
tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
weights_frequency = tf.get_variable(
name='weights_frequency',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[input_frequency_size, num_filters])
# Expand to add input channels dimensions.
# weights_frequency: [input_frequency_size, 1, num_filters]
weights_frequency = tf.expand_dims(weights_frequency, 1)
@ -506,8 +571,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
activations_time = new_memory
# Create the time filters.
weights_time = tf.Variable(
tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
weights_time = tf.get_variable(
name='weights_time',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[num_filters, input_time_size])
# Apply the time filter on the outputs of the feature filters.
# weights_time: [num_filters, input_time_size, 1]
# outputs: [num_filters, batch, 1]
@ -524,7 +591,8 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
units_output = tf.transpose(units_output)
# Apply bias.
bias = tf.Variable(tf.zeros([num_units]))
bias = tf.get_variable(
name='bias', initializer=tf.zeros_initializer, shape=[num_units])
first_bias = tf.nn.bias_add(units_output, bias)
# Relu.
@ -536,31 +604,135 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
first_dropout = first_relu
first_fc_output_channels = 256
first_fc_weights = tf.Variable(
tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
first_fc_weights = tf.get_variable(
name='first_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[num_units, first_fc_output_channels])
first_fc_bias = tf.get_variable(
name='first_fc_bias',
initializer=tf.zeros_initializer,
shape=[first_fc_output_channels])
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 256
second_fc_weights = tf.Variable(
tf.truncated_normal(
[first_fc_output_channels, second_fc_output_channels], stddev=0.01))
second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
second_fc_weights = tf.get_variable(
name='second_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_fc_output_channels, second_fc_output_channels])
second_fc_bias = tf.get_variable(
name='second_fc_bias',
initializer=tf.zeros_initializer,
shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_fc_output_channels, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc_weights = tf.get_variable(
name='final_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_fc_output_channels, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
initializer=tf.zeros_initializer,
shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
"""Builds a convolutional model aimed at microcontrollers.
Devices like DSPs and microcontrollers can have very small amounts of
memory and limited processing power. This model is designed to use less
than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.
Here's the layout of the graph:
(fingerprint_input)
v
[Conv2D]<-(weights)
v
[BiasAdd]<-(bias)
v
[Relu]
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
This doesn't produce particularly accurate results, but it's designed to be
used as the first stage of a pipeline, running on low-energy hardware that
can always be on, and waking higher-power chips only when a possible
utterance has been found, so that more accurate analysis can be done.
During training, a dropout node is introduced after the relu, controlled by a
placeholder.
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
model_settings: Dictionary of information about the model.
is_training: Whether the model is going to be used for training.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 10
first_filter_count = 8
first_weights = tf.get_variable(
name='first_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_filter_height, first_filter_width, 1, first_filter_count])
first_bias = tf.get_variable(
name='first_bias',
initializer=tf.zeros_initializer,
shape=[first_filter_count])
first_conv_stride_x = 2
first_conv_stride_y = 2
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
[1, first_conv_stride_y, first_conv_stride_x, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
if is_training:
first_dropout = tf.nn.dropout(first_relu, dropout_prob)
else:
first_dropout = first_relu
first_dropout_shape = first_dropout.get_shape()
first_dropout_output_width = first_dropout_shape[2]
first_dropout_output_height = first_dropout_shape[1]
first_dropout_element_count = int(
first_dropout_output_width * first_dropout_output_height *
first_filter_count)
flattened_first_dropout = tf.reshape(first_dropout,
[-1, first_dropout_element_count])
label_count = model_settings['label_count']
final_fc_weights = tf.get_variable(
name='final_fc_weights',
initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[first_dropout_element_count, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
initializer=tf.zeros_initializer,
shape=[label_count])
final_fc = (
tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
if is_training:
return final_fc, dropout_prob
else:
return final_fc
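As a rough, hedged check on the memory claims above (assumed settings: 1s clips at 16kHz, 30ms windows, 20ms stride, 40 feature bins, 'average' preprocessing, 4 labels), a sketch that builds tiny_conv and counts its trainable parameters; under these assumptions it comes to roughly 18K weights, which at one byte each after eight-bit quantization fits the stated 32KB flash budget:

import numpy as np
import tensorflow as tf
from tensorflow.examples.speech_commands import models

settings = models.prepare_model_settings(
    label_count=4, sample_rate=16000, clip_duration_ms=1000,
    window_size_ms=30, window_stride_ms=20, feature_bin_count=40,
    preprocess='average')
fingerprint = tf.placeholder(tf.float32, [None, settings['fingerprint_size']])
logits, dropout_prob = models.create_model(
    fingerprint, settings, 'tiny_conv', is_training=True)
# Sum the element counts of every trainable variable in the graph.
param_count = sum(int(np.prod(v.shape.as_list()))
                  for v in tf.trainable_variables())
print(param_count)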

tensorflow/examples/speech_commands/models_test.py

@ -26,12 +26,29 @@ from tensorflow.python.platform import test
class ModelsTest(test.TestCase):
def _modelSettings(self):
return models.prepare_model_settings(
label_count=10,
sample_rate=16000,
clip_duration_ms=1000,
window_size_ms=20,
window_stride_ms=10,
feature_bin_count=40,
preprocess="mfcc")
def testPrepareModelSettings(self):
self.assertIsNotNone(
models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))
models.prepare_model_settings(
label_count=10,
sample_rate=16000,
clip_duration_ms=1000,
window_size_ms=20,
window_stride_ms=10,
feature_bin_count=40,
preprocess="mfcc"))
def testCreateModelConvTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
@ -42,7 +59,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelConvInference(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
@ -51,7 +68,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
def testCreateModelLowLatencyConvTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@ -62,7 +79,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelFullyConnectedTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@ -73,7 +90,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelBadArchitecture(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
model_settings = self._modelSettings()
with self.test_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
@ -81,6 +98,17 @@ class ModelsTest(test.TestCase):
"bad_architecture", True)
self.assertTrue("not recognized" in str(e.exception))
def testCreateModelTinyConvTraining(self):
model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "tiny_conv", True)
self.assertIsNotNone(logits)
self.assertIsNotNone(dropout_prob)
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
if __name__ == "__main__":
test.main()

tensorflow/examples/speech_commands/train.py

@ -98,12 +98,12 @@ def main(_):
model_settings = models.prepare_model_settings(
len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
audio_processor = input_data.AudioProcessor(
FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
FLAGS.unknown_percentage,
FLAGS.data_url, FLAGS.data_dir,
FLAGS.silence_percentage, FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
FLAGS.testing_percentage, model_settings)
FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
@ -122,8 +122,25 @@ def main(_):
'lists, but are %d and %d long instead' % (len(training_steps_list),
len(learning_rates_list)))
fingerprint_input = tf.placeholder(
input_placeholder = tf.placeholder(
tf.float32, [None, fingerprint_size], name='fingerprint_input')
if FLAGS.quantize:
# TODO(petewarden): These values have been derived from the observed ranges
# of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
# they may need to be updated.
if FLAGS.preprocess == 'average':
fingerprint_min = 0.0
fingerprint_max = 2048.0
elif FLAGS.preprocess == 'mfcc':
fingerprint_min = -247.0
fingerprint_max = 30.0
else:
raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
' "average")' % (FLAGS.preprocess))
fingerprint_input = tf.fake_quant_with_min_max_args(
input_placeholder, fingerprint_min, fingerprint_max)
else:
fingerprint_input = input_placeholder
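A hedged illustration of what these fake-quantization bounds do to the input features: values are clamped to [min, max] and snapped onto the 256-level eight-bit grid (same TF 1.x op as used above):

import tensorflow as tf

features = tf.constant([-300.0, -247.0, 0.0, 30.0, 50.0])
quantized = tf.fake_quant_with_min_max_args(features, min=-247.0, max=30.0)
with tf.Session() as sess:
  print(sess.run(quantized))  # -300 clamps to -247, 50 clamps to 30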
logits, dropout_prob = models.create_model(
fingerprint_input,
@ -146,7 +163,8 @@ def main(_):
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
tf.summary.scalar('cross_entropy', cross_entropy_mean)
if FLAGS.quantize:
tf.contrib.quantize.create_training_graph(quant_delay=0)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
tf.float32, [], name='learning_rate_input')
@ -157,7 +175,9 @@ def main(_):
confusion_matrix = tf.confusion_matrix(
ground_truth_input, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
with tf.get_default_graph().name_scope('eval'):
tf.summary.scalar('cross_entropy', cross_entropy_mean)
tf.summary.scalar('accuracy', evaluation_step)
global_step = tf.train.get_or_create_global_step()
increment_global_step = tf.assign(global_step, global_step + 1)
@ -165,7 +185,7 @@ def main(_):
saver = tf.train.Saver(tf.global_variables())
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
merged_summaries = tf.summary.merge_all()
merged_summaries = tf.summary.merge_all(scope='eval')
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
sess.graph)
validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
@ -207,8 +227,11 @@ def main(_):
# Run the graph with this batch of training data.
train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
[
merged_summaries, evaluation_step, cross_entropy_mean, train_step,
increment_global_step
merged_summaries,
evaluation_step,
cross_entropy_mean,
train_step,
increment_global_step,
],
feed_dict={
fingerprint_input: train_fingerprints,
@ -364,10 +387,11 @@ if __name__ == '__main__':
default=10.0,
help='How far to move in time between spectrogram timeslices.',)
parser.add_argument(
'--dct_coefficient_count',
'--feature_bin_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
help='How many bins to use for the MFCC fingerprint',
)
parser.add_argument(
'--how_many_training_steps',
type=str,
@ -423,6 +447,16 @@ if __name__ == '__main__':
type=bool,
default=False,
help='Whether to check for invalid numbers during processing')
parser.add_argument(
'--quantize',
type=bool,
default=False,
help='Whether to train the model for eight-bit deployment')
parser.add_argument(
'--preprocess',
type=str,
default='mfcc',
help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)