Update speech commands training to support micro frontend
PiperOrigin-RevId: 236948121
parent 35d0488ab8
commit 4a464440b2
tensorflow/examples/speech_commands/BUILD

@@ -45,6 +45,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
+        "//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
tensorflow/examples/speech_commands/freeze.py

@@ -49,6 +49,14 @@ import input_data
 import models
 from tensorflow.python.framework import graph_util
 
+# If it's available, load the specialized feature generator. If this doesn't
+# work, try building with bazel instead of running the Python script directly.
+# bazel run tensorflow/examples/speech_commands:freeze_graph
+try:
+  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
+except ImportError:
+  frontend_op = None
+
 FLAGS = None
 
 
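The guarded import means freeze.py still loads when the TFLite micro frontend wasn't compiled in; frontend_op is simply left as None, and only the 'micro' path below will complain. A minimal sketch of how a caller could probe for the optional op (available_preprocess_modes is a hypothetical helper, not part of this commit):

def available_preprocess_modes():
  """List the preprocessing modes usable in this build."""
  modes = ['average', 'mfcc']
  if frontend_op is not None:  # None when the Bazel-built op is missing
    modes.append('micro')
  return modes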
@@ -70,7 +78,7 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
     feature_bin_count: Number of frequency bands to analyze.
     model_architecture: Name of the kind of model to generate.
     preprocess: How the spectrogram is processed to produce features, for
-      example 'mfcc' or 'average'.
+      example 'mfcc', 'average', or 'micro'.
 
   Raises:
     Exception: If the preprocessing mode isn't recognized.
@@ -106,9 +114,33 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
         spectrogram,
         sample_rate,
         dct_coefficient_count=model_settings['fingerprint_width'])
+  elif preprocess == 'micro':
+    if not frontend_op:
+      raise Exception(
+          'Micro frontend op is currently not available when running TensorFlow'
+          ' directly from Python, you need to build and run through Bazel, for'
+          ' example'
+          ' `bazel run tensorflow/examples/speech_commands:freeze_graph`'
+      )
+    sample_rate = model_settings['sample_rate']
+    window_size_ms = (model_settings['window_size_samples'] *
+                      1000) / sample_rate
+    window_step_ms = (model_settings['window_stride_samples'] *
+                      1000) / sample_rate
+    int16_input = tf.cast(
+        tf.multiply(decoded_sample_data.audio, 32767), tf.int16)
+    micro_frontend = frontend_op.audio_microfrontend(
+        int16_input,
+        sample_rate=sample_rate,
+        window_size=window_size_ms,
+        window_step=window_step_ms,
+        num_channels=model_settings['fingerprint_width'],
+        out_scale=1,
+        out_type=tf.float32)
+    fingerprint_input = tf.multiply(micro_frontend, (10.0 / 256.0))
   else:
-    raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
-                    ' "average")' % (preprocess))
+    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
+                    ' "average", or "micro")' % (preprocess))
 
   fingerprint_size = model_settings['fingerprint_size']
   reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])
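For reference, a standalone sketch of the new 'micro' path (not part of the commit). It assumes the op was built with Bazel so the guarded import above set frontend_op, and that the op accepts a flat 1-D int16 audio tensor (the graph above feeds it the decoded [samples, 1] WAV tensor instead); the float-to-int16 cast and the 10.0/256.0 rescale mirror the branch above.

import numpy as np
import tensorflow as tf

audio = tf.placeholder(tf.float32, [16000])  # one second of mono audio in [-1, 1]
int16_audio = tf.cast(tf.multiply(audio, 32767), tf.int16)
micro_features = frontend_op.audio_microfrontend(
    int16_audio,
    sample_rate=16000,
    window_size=30,  # milliseconds
    window_step=10,  # milliseconds
    num_channels=40,
    out_scale=1,
    out_type=tf.float32)
features = tf.multiply(micro_features, 10.0 / 256.0)  # same rescale as above

with tf.Session() as sess:
  result = sess.run(features, {audio: np.random.uniform(-1, 1, 16000)})
  print(result.shape)  # roughly (98, 40): one 40-channel row per 10 ms step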
tensorflow/examples/speech_commands/freeze_test.py

@@ -65,6 +65,24 @@ class FreezeTest(test.TestCase):
       ops = [node.op for node in sess.graph_def.node]
       self.assertEqual(0, ops.count('Mfcc'))
 
+  @test_util.run_deprecated_v1
+  def testCreateInferenceGraphWithMicro(self):
+    with self.cached_session() as sess:
+      freeze.create_inference_graph(
+          wanted_words='a,b,c,d',
+          sample_rate=16000,
+          clip_duration_ms=1000.0,
+          clip_stride_ms=30.0,
+          window_size_ms=30.0,
+          window_stride_ms=10.0,
+          feature_bin_count=40,
+          model_architecture='conv',
+          preprocess='micro')
+      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+      self.assertIsNotNone(
+          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+
   @test_util.run_deprecated_v1
   def testFeatureBinCount(self):
     with self.cached_session() as sess:
tensorflow/examples/speech_commands/input_data.py

@@ -37,6 +37,13 @@ from tensorflow.python.ops import io_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.util import compat
 
+# If it's available, load the specialized feature generator. If this doesn't
+# work, try building with bazel instead of running the Python script directly.
+try:
+  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
+except ImportError:
+  frontend_op = None
+
 MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
 SILENCE_LABEL = '_silence_'
 SILENCE_INDEX = 0
@@ -169,9 +176,12 @@ def get_features_range(model_settings):
   elif model_settings['preprocess'] == 'mfcc':
     features_min = -247.0
     features_max = 30.0
+  elif model_settings['preprocess'] == 'micro':
+    features_min = 0.0
+    features_max = 26.0
   else:
-    raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
-                    ' "average")' % (model_settings['preprocess']))
+    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
+                    ' "average", or "micro")' % (model_settings['preprocess']))
   return features_min, features_max
 
 
@@ -377,6 +387,7 @@ class AudioProcessor(object):
 
     Raises:
      ValueError: If the preprocessing mode isn't recognized.
+      Exception: If the preprocessor wasn't compiled in.
    """
    with tf.get_default_graph().name_scope('data'):
      desired_samples = model_settings['desired_samples']
@@ -442,9 +453,36 @@
             dct_coefficient_count=model_settings['fingerprint_width'])
         tf.summary.image(
             'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
+      elif model_settings['preprocess'] == 'micro':
+        if not frontend_op:
+          raise Exception(
+              'Micro frontend op is currently not available when running'
+              ' TensorFlow directly from Python, you need to build and run'
+              ' through Bazel'
+          )
+        sample_rate = model_settings['sample_rate']
+        window_size_ms = (model_settings['window_size_samples'] *
+                          1000) / sample_rate
+        window_step_ms = (model_settings['window_stride_samples'] *
+                          1000) / sample_rate
+        int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16)
+        micro_frontend = frontend_op.audio_microfrontend(
+            int16_input,
+            sample_rate=sample_rate,
+            window_size=window_size_ms,
+            window_step=window_step_ms,
+            num_channels=model_settings['fingerprint_width'],
+            out_scale=1,
+            out_type=tf.float32)
+        self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
+        tf.summary.image(
+            'micro',
+            tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
+            max_outputs=1)
       else:
-        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
-                         ' "average")' % (model_settings['preprocess']))
+        raise ValueError(
+            'Unknown preprocess mode "%s" (should be "mfcc", '
+            ' "average", or "micro")' % (model_settings['preprocess']))
 
       # Merge all the summaries and write them out to /tmp/retrain_logs (by
       # default)
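audio_microfrontend takes its window geometry in milliseconds, while model_settings stores sample counts, hence the two conversions above. A quick sanity check of that arithmetic (the sample counts are assumed from the default 16 kHz, 30 ms / 10 ms configuration):

sample_rate = 16000
window_size_samples = 480    # model_settings['window_size_samples'] at 30 ms
window_stride_samples = 160  # model_settings['window_stride_samples'] at 10 ms

window_size_ms = (window_size_samples * 1000) / sample_rate    # 30.0
window_step_ms = (window_stride_samples * 1000) / sample_rate  # 10.0
assert (window_size_ms, window_step_ms) == (30.0, 10.0)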
tensorflow/examples/speech_commands/input_data_test.py

@@ -202,6 +202,10 @@ class InputDataTest(test.TestCase):
   def testGetDataMfcc(self):
     self._runGetDataTest("mfcc", 30)
 
+  @test_util.run_deprecated_v1
+  def testGetDataMicro(self):
+    self._runGetDataTest("micro", 20)
+
   @test_util.run_deprecated_v1
   def testGetUnprocessedData(self):
     tmp_dir = self.get_temp_dir()
tensorflow/examples/speech_commands/models.py

@@ -71,9 +71,12 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
   elif preprocess == 'mfcc':
     average_window_width = -1
     fingerprint_width = feature_bin_count
+  elif preprocess == 'micro':
+    average_window_width = -1
+    fingerprint_width = feature_bin_count
   else:
-    raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
-                     ' "average")' % (preprocess))
+    raise ValueError('Unknown preprocess mode "%s" (should be "mfcc",'
+                     ' "average", or "micro")' % (preprocess))
   fingerprint_size = fingerprint_width * spectrogram_length
   return {
       'desired_samples': desired_samples,
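Because 'micro' sizes its fingerprint exactly like 'mfcc', the resulting dimensions can be checked by hand. The sketch below reproduces the sample-count arithmetic from earlier in prepare_model_settings (those formulas are not part of this diff, so treat them as an assumption), using the defaults exercised by the tests:

sample_rate = 16000        # Hz
clip_duration_ms = 1000
window_size_ms = 30.0
window_stride_ms = 10.0
feature_bin_count = 40

desired_samples = int(sample_rate * clip_duration_ms / 1000)         # 16000
window_size_samples = int(sample_rate * window_size_ms / 1000)       # 480
window_stride_samples = int(sample_rate * window_stride_ms / 1000)   # 160
length_minus_window = desired_samples - window_size_samples          # 15520
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)  # 98

fingerprint_width = feature_bin_count  # 'micro' keeps all 40 channels
fingerprint_size = fingerprint_width * spectrogram_length
assert fingerprint_size == 3920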
tensorflow/examples/speech_commands/train.py

@@ -446,7 +446,7 @@ if __name__ == '__main__':
       '--preprocess',
       type=str,
       default='mfcc',
-      help='Spectrogram processing mode. Can be "mfcc" or "average"')
+      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
 
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
tensorflow/examples/speech_commands/wav_to_features.py

@@ -56,7 +56,7 @@ def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
     window_stride_ms: How far to move in time between spectogram timeslices.
     feature_bin_count: How many bins to use for the feature fingerprint.
     quantize: Whether to train the model for eight-bit deployment.
-    preprocess: Spectrogram processing mode. Can be "mfcc" or "average".
+    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
     input_wav: Path to the audio WAV file to read.
     output_c_file: Where to save the generated C source file.
   """
@@ -86,14 +86,15 @@ def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
   f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
   f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
   if quantize:
-    f.write(' * --quantize \\\n')
+    f.write(' * --quantize=1 \\\n')
   f.write(' * --preprocess="%s" \\\n' % preprocess)
   f.write(' * --input_wav="%s" \\\n' % input_wav)
   f.write(' * --output_c_file="%s" \\\n' % output_c_file)
   f.write(' */\n\n')
-  f.write('const int g_%s_width = %d;\n' % (variable_base, features.shape[2]))
-  f.write(
-      'const int g_%s_height = %d;\n' % (variable_base, features.shape[1]))
+  f.write('const int g_%s_width = %d;\n' %
+          (variable_base, model_settings['fingerprint_width']))
+  f.write('const int g_%s_height = %d;\n' %
+          (variable_base, model_settings['spectrogram_length']))
   if quantize:
     features_min, features_max = input_data.get_features_range(model_settings)
     f.write('const unsigned char g_%s_data[] = {' % variable_base)
@@ -108,7 +109,7 @@ def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
         quantized_value = 255
       if i == 0:
         f.write('\n  ')
-      f.write('%d, ' % quantized_value)
+      f.write('%d, ' % (quantized_value))
       i = (i + 1) % 10
   else:
     f.write('const float g_%s_data[] = {\n' % variable_base)
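When quantize is set, the loop above maps each float feature into an unsigned byte using the range that input_data.get_features_range returns for the chosen mode. A compact restatement of that mapping (the clamping mirrors the loop above; the exact rounding in the script may differ):

def quantize_feature(value, features_min, features_max):
  """Scale a float feature into [0, 255], clamping out-of-range values."""
  quantized = int(round(
      (255 * (value - features_min)) / (features_max - features_min)))
  return max(0, min(255, quantized))

# Using the 'micro' range (0.0 .. 26.0) from get_features_range():
assert quantize_feature(-1.0, 0.0, 26.0) == 0     # clamped low
assert quantize_feature(13.0, 0.0, 26.0) == 128   # midpoint: 127.5 rounds to 128
assert quantize_feature(30.0, 0.0, 26.0) == 255   # clamped high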
@@ -168,7 +169,7 @@ if __name__ == '__main__':
       '--preprocess',
       type=str,
       default='mfcc',
-      help='Spectrogram processing mode. Can be "mfcc" or "average"')
+      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
   parser.add_argument(
       '--input_wav',
       type=str,
tensorflow/examples/speech_commands/wav_to_features_test.py

@@ -66,6 +66,22 @@ class WavToFeaturesTest(test.TestCase):
       content = f.read()
     self.assertTrue(b"const unsigned char g_input_data" in content)
 
+  @test_util.run_deprecated_v1
+  def testWavToFeaturesMicro(self):
+    tmp_dir = self.get_temp_dir()
+    wav_dir = os.path.join(tmp_dir, "wavs")
+    os.mkdir(wav_dir)
+    self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
+    input_file_path = os.path.join(tmp_dir, "input.wav")
+    output_file_path = os.path.join(tmp_dir, "output.c")
+    wav_data = self._getWavData()
+    self._saveTestWavFile(input_file_path, wav_data)
+    wav_to_features.wav_to_features(16000, 1000, 10, 10, 40, True, "micro",
+                                    input_file_path, output_file_path)
+    with open(output_file_path, "rb") as f:
+      content = f.read()
+    self.assertIn(b"const unsigned char g_input_data", content)
+
 
 if __name__ == "__main__":
   test.main()