Speech keyword detector tutorial

Adds a basic training script for a simple audio model to our examples. See third_party/docs_src/tutorials/audio_recognition.md for full documentation.

PiperOrigin-RevId: 165025732

This commit is contained in:
parent e9a8d75bc4
commit 0c6fd1703e

@@ -370,6 +370,7 @@ filegroup(
        "//tensorflow/examples/label_image:all_files",
        "//tensorflow/examples/learn:all_files",
        "//tensorflow/examples/saved_model:all_files",
        "//tensorflow/examples/speech_commands:all_files",
        "//tensorflow/examples/tutorials/estimators:all_files",
        "//tensorflow/examples/tutorials/mnist:all_files",
        "//tensorflow/examples/tutorials/word2vec:all_files",

@@ -28,6 +28,7 @@ tf_custom_op_py_library(
        "python/framework/tensor_util.py",
        "python/ops/__init__.py",
        "python/ops/arg_scope.py",
        "python/ops/audio_ops.py",
        "python/ops/checkpoint_ops.py",
        "python/ops/ops.py",
        "python/ops/prettyprint_ops.py",
@@ -50,6 +51,7 @@ tf_custom_op_py_library(
        ":gen_variable_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:audio_ops_gen",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",

tensorflow/contrib/framework/python/ops/audio_ops.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-short-docstring-punctuation
"""Audio processing and decoding ops.

@@decode_wav
@@encode_wav
@@audio_spectrogram
@@mfcc
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_audio_ops import *
# pylint: enable=wildcard-import

from tensorflow.python.util.all_util import remove_undocumented

remove_undocumented(__name__, [])

@@ -118,6 +118,17 @@ Status ReadValue(const string& data, T* value, int* offset) {
  return Status::OK();
}

Status ReadString(const string& data, int expected_length, string* value,
                  int* offset) {
  const int new_offset = *offset + expected_length;
  if (new_offset > data.size()) {
    return errors::InvalidArgument("Data too short when trying to read string");
  }
  *value = string(data.begin() + *offset, data.begin() + new_offset);
  *offset = new_offset;
  return Status::OK();
}

}  // namespace

Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,

@@ -254,17 +265,33 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
    // Skip over this unused section.
    offset += 2;
  }
  TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset));
  uint32 data_size;
  TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &data_size, &offset));
  *sample_count = data_size / bytes_per_sample;
  const uint32 data_count = *sample_count * *channel_count;
  float_values->resize(data_count);
  for (int i = 0; i < data_count; ++i) {
    int16 single_channel_value = 0;
    TF_RETURN_IF_ERROR(
        ReadValue<int16>(wav_string, &single_channel_value, &offset));
    (*float_values)[i] = Int16SampleToFloat(single_channel_value);

  bool was_data_found = false;
  while (offset < wav_string.size()) {
    string chunk_id;
    TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
    uint32 chunk_size;
    TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
    if (chunk_id == kDataChunkId) {
      if (was_data_found) {
        return errors::InvalidArgument("More than one data chunk found in WAV");
      }
      was_data_found = true;
      *sample_count = chunk_size / bytes_per_sample;
      const uint32 data_count = *sample_count * *channel_count;
      float_values->resize(data_count);
      for (int i = 0; i < data_count; ++i) {
        int16 single_channel_value = 0;
        TF_RETURN_IF_ERROR(
            ReadValue<int16>(wav_string, &single_channel_value, &offset));
        (*float_values)[i] = Int16SampleToFloat(single_channel_value);
      }
    } else {
      offset += chunk_size;
    }
  }
  if (!was_data_found) {
    return errors::InvalidArgument("No data chunk found in WAV");
  }
  return Status::OK();
}

@@ -62,7 +62,7 @@ Status DecodeWavShapeFn(InferenceContext* c) {
Status EncodeWavShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
  c->set_output(0, c->Scalar());
  return Status::OK();
}

@@ -104,7 +104,7 @@ Status MfccShapeFn(InferenceContext* c) {
  ShapeHandle spectrogram;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

  int32 dct_coefficient_count;
  TF_RETURN_IF_ERROR(

tensorflow/docs_src/tutorials/audio_recognition.md (new file, 551 lines)
@@ -0,0 +1,551 @@
# How to Train a Simple Audio Recognition Network

This tutorial will show you how to build a basic speech recognition network that
recognizes ten different words. It's important to know that real speech and
audio recognition systems are much more complex, but like MNIST for images, it
should give you a basic understanding of the techniques involved. Once you've
completed this tutorial, you'll have a model that tries to classify a one second
audio clip as either silence, an unknown word, "yes", "no", "up", "down",
"left", "right", "on", "off", "stop", or "go". You'll also be able to take this
model and run it in an Android application.

## Preparation

You should make sure you have TensorFlow installed, and since the script
downloads over 1GB of training data, you'll need a good internet connection and
enough free space on your machine. The training process itself can take several
hours, so make sure you have a machine available for that long.

## Training

To begin the training process, go to the TensorFlow source tree and run:

```bash
python tensorflow/examples/speech_commands/train.py
```

The script will start off by downloading the [Speech Commands
dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tar.gz),
which consists of 65,000 WAVE audio files of people saying thirty different
words. This data was collected by Google and released under a CC BY license, and
you can help improve it by [contributing five minutes of your own
voice](https://aiyprojects.withgoogle.com/open_speech_recording). The archive is
over 1GB, so this part may take a while, but you should see progress logs, and
once it's been downloaded you won't need to do this step again.

Once the downloading has completed, you'll see logging information that looks
like this:

```
I0730 16:53:44.766740 55030 train.py:176] Training from step: 1
I0730 16:53:47.289078 55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571
```

This shows that the initialization process is done and the training loop has
begun. You'll see that it outputs information for every training step. Here's a
breakdown of what it means:

`Step #1` shows that we're on the first step of the training loop. In this case
there are going to be 18,000 steps in total, so you can look at the step number
to get an idea of how close it is to finishing.

`rate 0.001000` is the learning rate that's controlling the speed of the
network's weight updates. Early on this is a comparatively high number (0.001),
but for later training cycles it will be reduced 10x, to 0.0001.

`accuracy 7.0%` shows how many classes were correctly predicted on this
training step. This value will often fluctuate a lot, but should increase on
average as training progresses. The model outputs an array of numbers, one for
each label, and each number is the predicted likelihood of the input being that
class. The predicted label is picked by choosing the entry with the highest
score. The scores are always between zero and one, with higher values
representing more confidence in the result.

`cross entropy 2.611571` is the result of the loss function that we're using to
guide the training process. This is a score that's obtained by comparing the
vector of scores from the current training run to the correct labels, and this
should trend downwards during training.
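
As a concrete illustration of how a prediction is read off that score vector,
here's a tiny sketch. The label order matches the tutorial, but the score values
are made up for the example:

```python
import numpy as np

labels = ["_silence_", "_unknown_", "yes", "no", "up", "down",
          "left", "right", "on", "off", "stop", "go"]
# Hypothetical output of the final softmax layer for one clip.
scores = np.array([0.01, 0.04, 0.81, 0.02, 0.01, 0.02,
                   0.03, 0.02, 0.01, 0.01, 0.01, 0.01])

# The predicted label is simply the entry with the highest score.
predicted = labels[int(np.argmax(scores))]
print(predicted, scores.max())  # -> yes 0.81
```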

After a hundred steps, you should see a line like this:

`I0730 16:54:41.813438 55030 train.py:252] Saving to
"/tmp/speech_commands_train/conv.ckpt-100"`

This is saving out the current trained weights to a checkpoint file. If your
training script gets interrupted, you can look for the last saved checkpoint and
then restart the script with
`--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-100` as a command line
argument to start from that point.

## Confusion Matrix

After four hundred steps, this information will be logged:

```
I0730 16:57:38.073667 55030 train.py:243] Confusion Matrix:
[[258 0 0 0 0 0 0 0 0 0 0 0]
[ 7 6 26 94 7 49 1 15 40 2 0 11]
[ 10 1 107 80 13 22 0 13 10 1 0 4]
[ 1 3 16 163 6 48 0 5 10 1 0 17]
[ 15 1 17 114 55 13 0 9 22 5 0 9]
[ 1 1 6 97 3 87 1 12 46 0 0 10]
[ 8 6 86 84 13 24 1 9 9 1 0 6]
[ 9 3 32 112 9 26 1 36 19 0 0 9]
[ 8 2 12 94 9 52 0 6 72 0 0 2]
[ 16 1 39 74 29 42 0 6 37 9 0 3]
[ 15 6 17 71 50 37 0 6 32 2 1 9]
[ 11 1 6 151 5 42 0 8 16 0 0 20]]
```

The first section is a [confusion
matrix](https://www.tensorflow.org/api_docs/python/tf/confusion_matrix). To
understand what it means, you first need to know the labels being used, which in
this case are "_silence_", "_unknown_", "yes", "no", "up", "down", "left",
"right", "on", "off", "stop", and "go". Each column represents a set of samples
that were predicted to be each label, so the first column represents all the
clips that were predicted to be silence, the second all those that were
predicted to be unknown words, the third "yes", and so on.

Each row represents clips by their correct, ground truth labels. The first row
is all the clips that were silence, the second clips that were unknown words,
the third "yes", etc.

This matrix can be more useful than just a single accuracy score because it
gives a good summary of what mistakes the network is making. In this example you
can see that all of the entries in the first row are zero, apart from the
initial one. Because the first row is all the clips that are actually silence,
this means that none of them were mistakenly labeled as words, so we have no
false negatives for silence. This shows the network is already getting pretty
good at distinguishing silence from words.

If we look down the first column though, we see a lot of non-zero values. The
column represents all the clips that were predicted to be silence, so positive
numbers outside of the first cell are errors. This means that some clips of real
spoken words are actually being predicted to be silence, so we do have quite a
few false positives.

A perfect model would produce a confusion matrix where all of the entries were
zero apart from a diagonal line through the center. Spotting deviations from
that pattern can help you figure out how the model is most easily confused, and
once you've identified the problems you can address them by adding more data or
cleaning up categories.
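
If you want to compute the same kind of matrix for your own predictions, the
`tf.confusion_matrix` op the link above points to is all you need. Here's a
minimal sketch; the label indices are invented for the example, with 0 meaning
"_silence_", 1 "_unknown_", 2 "yes", and so on:

```python
import tensorflow as tf

ground_truth = [0, 2, 2, 3, 1]   # correct label index for each clip
predictions = [0, 2, 3, 3, 0]    # what the model predicted

confusion = tf.confusion_matrix(
    labels=ground_truth, predictions=predictions, num_classes=12)

with tf.Session() as sess:
  # Rows are the correct labels and columns are the predictions, matching the
  # layout of the matrix that train.py logs.
  print(sess.run(confusion))
```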

## Validation

After the confusion matrix, you should see a line like this:

`I0730 16:57:38.073777 55030 train.py:245] Step 400: Validation accuracy = 26.3%
(N=3093)`

It's good practice to separate your data set into three categories. The largest
(in this case roughly 80% of the data) is used for training the network, a
smaller set (10% here, known as "validation") is reserved for evaluation of the
accuracy during training, and another set (the last 10%, "testing") is used to
evaluate the accuracy once after the training is complete.

The reason for this split is that there's always a danger that networks will
start memorizing their inputs during training. By keeping the validation set
separate, you can ensure that the model works with data it's never seen before.
The testing set is an additional safeguard to make sure that you haven't just
been tweaking your model in a way that happens to work for both the training and
validation sets, but not a broader range of inputs.

The training script automatically separates the data set into these three
categories, and the logging line above shows the accuracy of the model when run
on the validation set. Ideally, this should stick fairly close to the training
accuracy. If the training accuracy increases but the validation doesn't, that's
a sign that overfitting is occurring, and your model is only learning things
about the training clips, not broader patterns that generalize.

## TensorBoard

A good way to visualize how the training is progressing is using TensorBoard. By
default, the script saves out events to /tmp/retrain_logs, and you can load
these by running:

`tensorboard --logdir /tmp/retrain_logs`

Then navigate to [http://localhost:6006](http://localhost:6006) in your browser,
and you'll see charts and graphs showing your model's progress.

<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/speech_commands_tensorflow.png"/>
</div>

## Training Finished

After a few hours of training (depending on your machine's speed), the script
should have completed all 18,000 steps. It will print out a final confusion
matrix, along with an accuracy score, all run on the testing set. With the
default settings, you should see an accuracy of between 85% and 90%.

Because audio recognition is particularly useful on mobile devices, next we'll
export it to a compact format that's easy to work with on those platforms. To do
that, run this command line:

```
python tensorflow/examples/speech_commands/freeze.py \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \
--output_file=/tmp/my_frozen_graph.pb
```

Once the frozen model has been created, you can test it with the `label_wav.py`
script, like this:

```
python tensorflow/examples/speech_commands/label_wav.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
```

This should print out three labels:

```
left (score = 0.81477)
right (score = 0.14139)
_unknown_ (score = 0.03808)
```

Hopefully "left" is the top score since that's the correct label, but since the
training is random it may not for the first file you try. Experiment with some
of the other .wav files in that same folder to see how well it does.

The scores are between zero and one, and higher values mean the model is more
confident in its prediction.

## How does this Model Work?

The architecture used in this tutorial is based on some described in the paper
[Convolutional Neural Networks for Small-footprint Keyword
Spotting](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
It was chosen because it's comparatively simple, quick to train, and easy to
understand, rather than being state of the art. There are lots of different
approaches to building neural network models to work with audio, including
[recurrent networks](https://svds.com/tensorflow-rnn-tutorial/) or [dilated
(atrous)
convolutions](https://deepmind.com/blog/wavenet-generative-model-raw-audio/).
This tutorial is based on the kind of convolutional network that will feel very
familiar to anyone who's worked with image recognition. That may seem surprising
at first though, since audio is inherently a one-dimensional continuous signal
across time, not a 2D spatial problem.

We solve that issue by defining a window of time we believe our spoken words
should fit into, and converting the audio signal in that window into an image.
This is done by grouping the incoming audio samples into short segments, just a
few milliseconds long, and calculating the strength of the frequencies across a
set of bands. Each set of frequency strengths from a segment is treated as a
vector of numbers, and those vectors are arranged in time order to form a
two-dimensional array. This array of values can then be treated like a
single-channel image, and is known as a
[spectrogram](https://en.wikipedia.org/wiki/Spectrogram). If you want to view
what kind of image an audio sample produces, you can run the
`wav_to_spectrogram` tool:

```
bazel run tensorflow/examples/wav_to_spectrogram:wav_to_spectrogram -- \
--input_wav=/tmp/speech_dataset/happy/ab00c4b2_nohash_0.wav \
--output_png=/tmp/spectrogram.png
```

If you open up `/tmp/spectrogram.png` you should see something like this:

<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/spectrogram.png"/>
</div>

Because of TensorFlow's memory order, time in this image is increasing from top
to bottom, with frequencies going from left to right, unlike the usual
convention for spectrograms where time is left to right. You should be able to
see a couple of distinct parts, with the first syllable "Ha" distinct from
"ppy".

Because the human ear is more sensitive to some frequencies than others, it's
been traditional in speech recognition to do further processing to this
representation to turn it into a set of [Mel-Frequency Cepstral
Coefficients](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum), or MFCCs
for short. This is also a two-dimensional, one-channel representation so it can
be treated like an image too. If you're targeting general sounds rather than
speech you may find you can skip this step and operate directly on the
spectrograms.

The image that's produced by these processing steps is then fed into a
multi-layer convolutional neural network, with a fully-connected layer followed
by a softmax at the end. You can see the definition of this portion in
[tensorflow/examples/speech_commands/models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py).
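
Those decoding, spectrogram, and MFCC steps are available as TensorFlow ops
(`decode_wav`, `audio_spectrogram`, and `mfcc`, wrapped by
`tensorflow.contrib.framework.python.ops.audio_ops`), so here's a rough sketch
of the feature pipeline on one of the dataset's clips. The window, stride, and
coefficient values are just illustrative defaults:

```python
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

wav_data = tf.read_file('/tmp/speech_dataset/left/a5d485dc_nohash_0.wav')

# Decode to float samples in the range -1.0 to 1.0: one second of 16kHz mono.
decoded = contrib_audio.decode_wav(wav_data, desired_channels=1,
                                   desired_samples=16000)

# Strength of the frequencies in short, overlapping windows of samples
# (480 samples is 30ms at 16kHz, and the 160-sample stride is 10ms).
spectrogram = contrib_audio.audio_spectrogram(decoded.audio, window_size=480,
                                              stride=160, magnitude_squared=True)

# Compress the spectrogram into Mel-Frequency Cepstral Coefficients, the
# single-channel 'image' that the convolutional layers are fed.
mfcc = contrib_audio.mfcc(spectrogram, decoded.sample_rate,
                          dct_coefficient_count=40)

with tf.Session() as sess:
  print(sess.run(mfcc).shape)  # roughly (1, 98, 40) with these settings
```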

## Streaming Accuracy

Most audio recognition applications need to run on a continuous stream of audio,
rather than on individual clips. A typical way to use a model in this
environment is to apply it repeatedly at different offsets in time and average
the results over a short window to produce a smoothed prediction. If you think
of the input as an image, it's continuously scrolling along the time axis. The
words we want to recognize can start at any time, so we need to take a series of
snapshots to have a chance of having an alignment that captures most of the
utterance in the time window we feed into the model. If we sample at a high
enough rate, then we have a good chance of capturing the word in multiple
windows, so averaging the results improves the overall confidence of the
prediction.

For an example of how you can use your model on streaming data, you can look at
[test_streaming_accuracy.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/).
This uses the
[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h)
class to run through a long-form input audio, try to spot words, and compare
those predictions against a ground truth list of labels and times. This makes it
a good example of applying a model to a stream of audio signals over time.

You'll need a long audio file to test it against, along with labels showing
where each word was spoken. If you don't want to record one yourself, you can
generate some synthetic test data using the `generate_streaming_test_wav`
utility. By default this will create a ten minute .wav file with words roughly
every three seconds, and a text file containing the ground truth of when each
word was spoken. These words are pulled from the test portion of your current
dataset, mixed in with background noise. To run it, use:

```
bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav
```

This will save a .wav file to `/tmp/speech_commands_train/streaming_test.wav`,
and a text file listing the labels to
`/tmp/speech_commands_train/streaming_test_labels.txt`. You can then run
accuracy testing with:

```
bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_commands_train/streaming_test.wav \
--ground_truth=/tmp/speech_commands_train/streaming_test_labels.txt \
--verbose
```

This will output information about the number of words correctly matched, how
many were given the wrong labels, and how many times the model triggered when
there was no real word spoken. There are various parameters that control how the
signal averaging works, including `--average_window_ms` which sets the length of
time to average results over, `--sample_stride_ms` which is the time between
applications of the model, `--suppression_ms` which stops subsequent word
detections from triggering for a certain time after an initial one is found, and
`--detection_threshold`, which controls how high the average score must be
before it's considered a solid result.

You'll see that the streaming accuracy outputs three numbers, rather than just
the one metric used in training. This is because different applications have
varying requirements, with some being able to tolerate frequent incorrect
results as long as real words are found (high recall), while others are very
focused on ensuring the predicted labels are highly likely to be correct even if
some aren't detected (high precision). The numbers from the tool give you an
idea of how your model will perform in an application, and you can try tweaking
the signal averaging parameters to tune it to give the kind of performance you
want. To understand what the right parameters are for your application, you can
look at generating an [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)
to help you understand the tradeoffs.

## RecognizeCommands

The streaming accuracy tool uses a simple decoder contained in a small
C++ class called
[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h).
This class is fed the output of running the TensorFlow model over time, it
averages the signals, and returns information about a label when it has enough
evidence to think that a recognized word has been found. The implementation is
fairly small, just keeping track of the last few predictions and averaging them,
so it's easy to port to other platforms and languages as needed. For example,
it's convenient to do something similar at the Java level on Android, or Python
on the Raspberry Pi. As long as these implementations share the same logic, you
can tune the parameters that control the averaging using the streaming test
tool, and then transfer them over to your application to get similar results.
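
To make that logic concrete, here's a rough Python sketch of the same
averaging-and-thresholding idea. It's an illustrative re-implementation rather
than the actual `RecognizeCommands` code, with parameters named after the
streaming accuracy flags described above:

```python
import collections
import numpy as np

class CommandRecognizer(object):
  """Smooths a stream of per-clip score vectors and reports confident words."""

  def __init__(self, labels, average_window_ms=500, suppression_ms=1500,
               detection_threshold=0.7):
    self.labels = labels
    self.average_window_ms = average_window_ms
    self.suppression_ms = suppression_ms
    self.detection_threshold = detection_threshold
    self.previous_results = collections.deque()  # (time_ms, scores) pairs
    self.previous_top_time = -suppression_ms

  def process_latest_result(self, scores, current_time_ms):
    # Keep only the results that fall inside the averaging window.
    self.previous_results.append((current_time_ms, np.asarray(scores)))
    while (current_time_ms - self.previous_results[0][0]
           > self.average_window_ms):
      self.previous_results.popleft()

    # Average the recent score vectors and pick the strongest label.
    average_scores = np.mean([s for _, s in self.previous_results], axis=0)
    top_index = int(np.argmax(average_scores))
    top_score = average_scores[top_index]

    # Only fire when the score is high enough and we haven't fired recently.
    time_since_last = current_time_ms - self.previous_top_time
    if (top_score > self.detection_threshold and
        time_since_last > self.suppression_ms):
      self.previous_top_time = current_time_ms
      return self.labels[top_index], top_score
    return None, top_score
```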

## Advanced Training

The defaults for the training script are designed to produce good end to end
results in a comparatively small file, but there are a lot of options you can
change to customize the results for your own requirements.

### Custom Training Data

By default the script will download the [Speech Commands
dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tar.gz), but
you can also supply your own training data. To train on your own data, you
should make sure that you have at least several hundred recordings of each sound
you would like to recognize, and arrange them into folders by class. For
example, if you were trying to recognize dog barks from cat miaows, you would
create a root folder called `animal_sounds`, and then within that two
sub-folders called `bark` and `miaow`. You would then organize your audio files
into the appropriate folders.

To point the script to your new audio files, you'll need to set `--data_url=` to
disable downloading of the Speech Commands dataset, and
`--data_dir=/your/data/folder/` to find the files you've just created.

The files themselves should be 16-bit little-endian PCM-encoded WAVE format. The
sample rate defaults to 16,000, but as long as all your audio is consistently
the same rate (the script doesn't support resampling) you can change this with
the `--sample_rate` argument. The clips should also all be roughly the same
duration. The default expected duration is one second, but you can set this with
the `--clip_duration_ms` flag. If you have clips with variable amounts of
silence at the start, you can look at word alignment tools to standardize them
([here's a quick and dirty approach you can use
too](https://petewarden.com/2017/07/17/a-quick-hack-to-align-single-word-audio-recordings/)).

One issue to watch out for is that you may have very similar repetitions of the
same sounds in your dataset, and these can give misleading metrics if they're
spread across your training, validation, and test sets. For example, the Speech
Commands set has people repeating the same word multiple times. Each one of
those repetitions is likely to be pretty close to the others, so if training was
overfitting and memorizing one, it could perform unrealistically well when it
saw a very similar copy in the test set. To avoid this danger, Speech Commands
tries to ensure that all clips featuring the same word spoken by a single person
are put into the same partition. Clips are assigned to training, test, or
validation sets based on a hash of their filename, to ensure that the
assignments remain steady even as new clips are added and avoid any training
samples migrating into the other sets. To make sure that all a given speaker's
words are in the same bucket, [the hashing
function](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/input_data.py)
ignores anything in a filename after '_nohash_' when calculating the
assignments. This means that if you have file names like `pete_nohash_0.wav` and
`pete_nohash_1.wav`, they're guaranteed to be in the same set.
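
Here's a simplified sketch of that kind of hash-based assignment. The real logic
lives in `input_data.py`, so treat this as an illustration that assumes an
80/10/10 split:

```python
import hashlib
import re

def which_set(filename, validation_percentage=10, testing_percentage=10):
  """Deterministically assigns a clip to 'training', 'validation' or 'testing'."""
  base_name = filename.split('/')[-1]
  # Ignore anything after '_nohash_' so that all of one speaker's repetitions
  # of a word land in the same partition.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  hash_value = int(hashlib.sha1(hash_name.encode('utf-8')).hexdigest(), 16)
  percentage_hash = hash_value % 100
  if percentage_hash < validation_percentage:
    return 'validation'
  elif percentage_hash < validation_percentage + testing_percentage:
    return 'testing'
  return 'training'

# Both repetitions hash identically, so they end up in the same set.
print(which_set('pete_nohash_0.wav') == which_set('pete_nohash_1.wav'))  # True
```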

### Unknown Class

It's likely that your application will hear sounds that aren't in your training
set, and you'll want the model to indicate that it doesn't recognize the noise
in those cases. To help the network learn what sounds to ignore, you need to
provide some clips of audio that are neither of your classes. To do this in the
animal sounds example, you'd create `quack`, `oink`, and `moo` subfolders and
populate them with noises from other animals your users might encounter. The
`--wanted_words` argument to the script defines which classes you care about;
all the others mentioned in subfolder names will be used to populate an
`_unknown_` class during training. The Speech Commands dataset has twenty words
in its unknown classes, including the digits zero through nine and random names
like "Sheila".

By default 10% of the training examples are picked from the unknown classes, but
you can control this with the `--unknown_percentage` flag. Increasing this will
make the model less likely to mistake unknown words for wanted ones, but making
it too large can backfire as the model might decide it's safest to categorize
all words as unknown!

### Background Noise

Real applications have to recognize audio even when there are other irrelevant
sounds happening in the environment. To build a model that's robust to this kind
of interference, we need to train against recorded audio with similar
properties. The files in the Speech Commands dataset were captured on a variety
of devices by users in many different environments, not in a studio, so that
helps add some realism to the training. To add even more, you can mix in random
segments of environmental audio to the training inputs. In the Speech Commands
set there's a special folder called `_background_noise_` which contains
minute-long WAVE files with white noise and recordings of machinery and everyday
household activity.

Small snippets of these files are chosen at random and mixed at a low volume
into clips during training. The loudness is also chosen randomly, and controlled
by the `--background_volume` argument as a proportion where 0 is silence, and 1
is full volume. Not all clips have background added, so the
`--background_frequency` flag controls what proportion have them mixed in.

Your own application might operate in its own environment with different
background noise patterns than these defaults, so you can supply your own audio
clips in the `_background_noise_` folder. These should be the same sample rate
as your main dataset, but much longer in duration so that a good set of random
segments can be selected from them.
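
As a rough illustration of what that mixing step does, here's a simplified
numpy sketch (not the actual `input_data.py` implementation):

```python
import numpy as np

def mix_in_background(clip, background, background_volume=0.1):
  """Adds a random segment of `background` to `clip` at the given volume."""
  start = np.random.randint(0, len(background) - len(clip))
  segment = background[start:start + len(clip)]
  mixed = clip + background_volume * segment
  # Keep the result in the valid [-1.0, 1.0] range for float PCM samples.
  return np.clip(mixed, -1.0, 1.0)

# One second of 16kHz audio mixed with a segment of a minute-long noise file.
clip = np.zeros(16000, dtype=np.float32)          # stand-in for a word clip
background = np.random.uniform(-1, 1, 16000 * 60).astype(np.float32)
noisy_clip = mix_in_background(clip, background, background_volume=0.1)
```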

### Silence

In most cases the sounds you care about will be intermittent and so it's
important to know when there's no matching audio. To support this, there's a
special `_silence_` label that indicates when the model detects nothing
interesting. Because there's never complete silence in real environments, we
actually have to supply examples with quiet and irrelevant audio. For this, we
reuse the `_background_noise_` folder that's also mixed into real clips,
pulling short sections of the audio data and feeding those in with the ground
truth class of `_silence_`. By default 10% of the training data is supplied like
this, but the `--silence_percentage` flag can be used to control the proportion.
As with unknown words, setting this higher can weight the model results in favor
of true positives for silence, at the expense of false negatives for words, but
too large a proportion can cause it to fall into the trap of always guessing
silence.

### Time Shifting

Adding in background noise is one way of distorting the training data in a
realistic way to effectively increase the size of the dataset, and so increase
overall accuracy, and time shifting is another. This involves a random offset in
time of the training sample data, so that a small part of the start or end is
cut off and the opposite section is padded with zeroes. This mimics the natural
variations in starting time in the training data, and is controlled with the
`--time_shift_ms` flag, which defaults to 100ms. Increasing this value will
provide more variation, but at the risk of cutting off important parts of the
audio. A related way of augmenting the data with realistic distortions is by
using [time stretching and pitch scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling),
but that's outside the scope of this tutorial.
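
Here's a minimal numpy sketch of that random shift; again, this is illustrative
rather than the exact code used by the training script:

```python
import numpy as np

def time_shift(clip, time_shift_ms=100, sample_rate=16000):
  """Randomly shifts a clip forwards or backwards, padding with zeros."""
  shift_amount = int(sample_rate * time_shift_ms / 1000)
  shift = np.random.randint(-shift_amount, shift_amount + 1)
  if shift > 0:
    # Pad the start and drop the end: the word starts a little later.
    return np.concatenate([np.zeros(shift, clip.dtype), clip[:-shift]])
  elif shift < 0:
    # Drop the start and pad the end: the word starts a little earlier.
    return np.concatenate([clip[-shift:], np.zeros(-shift, clip.dtype)])
  return clip

shifted = time_shift(np.random.uniform(-1, 1, 16000).astype(np.float32))
```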

## Customizing the Model

The default model used for this script is pretty large, taking over 800 million
FLOPs for each inference and using 940,000 weight parameters. This runs at
usable speeds on desktop machines or modern phones, but it involves too many
calculations to run at interactive speeds on devices with more limited
resources. To support these use cases, there's an alternative model available,
based on the 'cnn-one-fstride4' architecture described in the [Convolutional
Neural Networks for Small-footprint Keyword Spotting
paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
The number of weight parameters is about the same, but it only needs 11 million
FLOPs to run one prediction, making it much faster.

To use this model, you can specify `--model_architecture=low_latency_conv` on
the command line. You'll also need to update the training rates and the number
of steps, so the full command will look like:

```
python tensorflow/examples/speech_commands/train.py \
--model_architecture=low_latency_conv \
--how_many_training_steps=20000,6000 \
--learning_rate=0.01,0.001
```

This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.

If you want to experiment with customizing models, a good place to start is by
tweaking the spectrogram creation parameters. This has the effect of altering
the size of the input image to the model, and the creation code in
[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
will adjust the number of computations and weights automatically to fit with
different dimensions. If you make the input smaller, the model will need fewer
computations to process it, so it can be a great way to trade off some accuracy
for improved latency. The `--window_stride_ms` flag controls how far apart each
frequency analysis sample is from the previous one. If you increase this value,
then fewer samples will be taken for a given duration, and the time axis of the
input will shrink. The `--dct_coefficient_count` flag controls how many buckets
are used for the frequency counting, so reducing this will shrink the input in
the other dimension. The `--window_size_ms` argument doesn't affect the size,
but does control how wide the area used to calculate the frequencies is for each
sample. Reducing the duration of the training samples, controlled by
`--clip_duration_ms`, can also help if the sounds you're looking for are short,
since that also reduces the time dimension of the input. You'll need to make
sure that all your training data contains the right audio in the initial portion
of the clip though.
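
To get a rough feel for how those flags set the input size, here's the
arithmetic for the default settings of one-second clips, a 10ms stride, and 40
coefficients. Treat the frame count as approximate, since the exact number also
depends on the window size:

```python
clip_duration_ms = 1000      # --clip_duration_ms
window_stride_ms = 10        # --window_stride_ms
dct_coefficient_count = 40   # --dct_coefficient_count

time_frames = clip_duration_ms // window_stride_ms   # ~100 rows of features
input_size = time_frames * dct_coefficient_count     # ~4,000 values per clip
print(time_frames, input_size)
```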

If you have an entirely different model in mind for your problem, you may find
that you can plug it into
[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
and have the rest of the script handle all of the preprocessing and training
mechanics. You would add a new clause to `create_model`, looking for the name of
your architecture and then calling a model creation function. This function is
given the size of the spectrogram input, along with other model information, and
is expected to create TensorFlow ops to read that in and produce an output
prediction vector, and a placeholder to control the dropout rate. The rest of
the script will handle integrating this model into a larger graph doing the
input calculations and applying softmax and a loss function to train it.
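
In outline, plugging in your own architecture might look something like the
sketch below. The `my_tiny_conv` name, the layer choices, and the exact
`create_model` signature are hypothetical and should be checked against
models.py; the sketch is just meant to show the shape of the change:

```python
import tensorflow as tf

# Sketch of a new clause in models.py's create_model dispatch function.
def create_model(fingerprint_input, model_settings, model_architecture,
                 is_training):
  if model_architecture == 'my_tiny_conv':  # hypothetical architecture name
    return create_my_tiny_conv_model(fingerprint_input, model_settings,
                                     is_training)
  raise Exception('Unknown architecture: ' + model_architecture)


def create_my_tiny_conv_model(fingerprint_input, model_settings, is_training):
  """Hypothetical model: a single conv layer feeding a fully-connected output."""
  if is_training:
    # Used as the keep probability, matching the convention in models.py.
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_time_size = model_settings['spectrogram_length']
  input_freq_size = model_settings['dct_coefficient_count']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_freq_size, 1])
  conv = tf.layers.conv2d(fingerprint_4d, filters=8, kernel_size=[8, 8],
                          activation=tf.nn.relu)
  if is_training:
    conv = tf.nn.dropout(conv, dropout_prob)
  flattened = tf.contrib.layers.flatten(conv)
  logits = tf.layers.dense(flattened, model_settings['label_count'])
  if is_training:
    return logits, dropout_prob
  return logits
```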

One common problem when you're adjusting models and training hyper-parameters is
that not-a-number values can creep in, thanks to numerical precision issues. In
general you can solve these by reducing the magnitude of things like learning
rates and weight initialization functions, but if they're persistent you can
enable the `--check_nans` flag to track down the source of the errors. This will
insert check ops between most regular operations in TensorFlow, and abort the
training process with a useful error message when they're encountered.

tensorflow/examples/speech_commands/BUILD (new file, 258 lines)
@@ -0,0 +1,258 @@
package(
    default_visibility = [
        "//visibility:public",
    ],
)

licenses(["notice"])  # Apache 2.0

exports_files([
    "LICENSE",
])

load("//tensorflow:tensorflow.bzl", "tf_py_test")

py_library(
    name = "models",
    srcs = [
        "models.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "models_test",
    size = "small",
    srcs = ["models_test.py"],
    additional_deps = [
        ":models",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "input_data",
    srcs = [
        "input_data.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "input_data_test",
    size = "small",
    srcs = ["input_data_test.py"],
    additional_deps = [
        ":input_data",
        "//tensorflow/python:client_testlib",
    ],
)

py_binary(
    name = "train",
    srcs = [
        "train.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":input_data",
        ":models",
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

py_binary(
    name = "freeze",
    srcs = [
        "freeze.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":input_data",
        ":models",
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "freeze_test",
    size = "small",
    srcs = ["freeze_test.py"],
    additional_deps = [
        ":freeze",
        "//tensorflow/python:client_testlib",
    ],
)

py_binary(
    name = "generate_streaming_test_wav",
    srcs = [
        "generate_streaming_test_wav.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":input_data",
        ":models",
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "generate_streaming_test_wav_test",
    size = "small",
    srcs = ["generate_streaming_test_wav_test.py"],
    additional_deps = [
        ":generate_streaming_test_wav",
        "//tensorflow/python:client_testlib",
    ],
)

cc_binary(
    name = "label_wav_cc",
    srcs = [
        "label_wav.cc",
    ],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

py_binary(
    name = "label_wav",
    srcs = [
        "label_wav.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

tf_py_test(
    name = "label_wav_test",
    size = "medium",
    srcs = ["label_wav_test.py"],
    additional_deps = [
        ":label_wav",
        "//tensorflow/python:client_testlib",
    ],
)

cc_library(
    name = "recognize_commands",
    srcs = [
        "recognize_commands.cc",
    ],
    hdrs = [
        "recognize_commands.h",
    ],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

cc_test(
    name = "recognize_commands_test",
    size = "medium",
    srcs = [
        "recognize_commands_test.cc",
    ],
    deps = [
        ":recognize_commands",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

cc_library(
    name = "accuracy_utils",
    srcs = [
        "accuracy_utils.cc",
    ],
    hdrs = [
        "accuracy_utils.h",
    ],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

cc_test(
    name = "accuracy_utils_test",
    size = "medium",
    srcs = [
        "accuracy_utils_test.cc",
    ],
    deps = [
        ":accuracy_utils",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

cc_binary(
    name = "test_streaming_accuracy",
    srcs = [
        "test_streaming_accuracy.cc",
    ],
    deps = [
        ":accuracy_utils",
        ":recognize_commands",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

tensorflow/examples/speech_commands/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
# Speech Commands Example

This is a basic speech recognition example. For more information, see the
tutorial at http://tensorflow.org/tutorials/audio_recognition.

tensorflow/examples/speech_commands/accuracy_utils.cc (new file, 138 lines)
@@ -0,0 +1,138 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/accuracy_utils.h"

#include <fstream>
#include <iomanip>
#include <unordered_set>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {

Status ReadGroundTruthFile(const string& file_name,
                           std::vector<std::pair<string, int64>>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Ground truth file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    std::vector<string> pieces = tensorflow::str_util::Split(line, ',');
    if (pieces.size() != 2) {
      continue;
    }
    float timestamp;
    if (!tensorflow::strings::safe_strtof(pieces[1].c_str(), &timestamp)) {
      return tensorflow::errors::InvalidArgument(
          "Wrong number format at line: ", line);
    }
    string label = pieces[0];
    auto timestamp_int64 = static_cast<int64>(timestamp);
    result->push_back({label, timestamp_int64});
  }
  std::sort(result->begin(), result->end(),
            [](const std::pair<string, int64>& left,
               const std::pair<string, int64>& right) {
              return left.second < right.second;
            });
  return Status::OK();
}

void CalculateAccuracyStats(
    const std::vector<std::pair<string, int64>>& ground_truth_list,
    const std::vector<std::pair<string, int64>>& found_words,
    int64 up_to_time_ms, int64 time_tolerance_ms,
    StreamingAccuracyStats* stats) {
  int64 latest_possible_time;
  if (up_to_time_ms == -1) {
    latest_possible_time = std::numeric_limits<int64>::max();
  } else {
    latest_possible_time = up_to_time_ms + time_tolerance_ms;
  }
  stats->how_many_ground_truth_words = 0;
  for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
    const int64 ground_truth_time = ground_truth.second;
    if (ground_truth_time > latest_possible_time) {
      break;
    }
    ++stats->how_many_ground_truth_words;
  }

  stats->how_many_false_positives = 0;
  stats->how_many_correct_words = 0;
  stats->how_many_wrong_words = 0;
  std::unordered_set<int64> has_ground_truth_been_matched;
  for (const std::pair<string, int64>& found_word : found_words) {
    const string& found_label = found_word.first;
    const int64 found_time = found_word.second;
    const int64 earliest_time = found_time - time_tolerance_ms;
    const int64 latest_time = found_time + time_tolerance_ms;
    bool has_match_been_found = false;
    for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
      const int64 ground_truth_time = ground_truth.second;
      if ((ground_truth_time > latest_time) ||
          (ground_truth_time > latest_possible_time)) {
        break;
      }
      if (ground_truth_time < earliest_time) {
        continue;
      }
      const string& ground_truth_label = ground_truth.first;
      if ((ground_truth_label == found_label) &&
          (has_ground_truth_been_matched.count(ground_truth_time) == 0)) {
        ++stats->how_many_correct_words;
      } else {
        ++stats->how_many_wrong_words;
      }
      has_ground_truth_been_matched.insert(ground_truth_time);
      has_match_been_found = true;
      break;
    }
    if (!has_match_been_found) {
      ++stats->how_many_false_positives;
    }
  }
  stats->how_many_ground_truth_matched = has_ground_truth_been_matched.size();
}

void PrintAccuracyStats(const StreamingAccuracyStats& stats) {
  if (stats.how_many_ground_truth_words == 0) {
    LOG(INFO) << "No ground truth yet, " << stats.how_many_false_positives
              << " false positives";
  } else {
    float any_match_percentage =
        (stats.how_many_ground_truth_matched * 100.0f) /
        stats.how_many_ground_truth_words;
    float correct_match_percentage = (stats.how_many_correct_words * 100.0f) /
                                     stats.how_many_ground_truth_words;
    float wrong_match_percentage = (stats.how_many_wrong_words * 100.0f) /
                                   stats.how_many_ground_truth_words;
    float false_positive_percentage =
        (stats.how_many_false_positives * 100.0f) /
        stats.how_many_ground_truth_words;

    LOG(INFO) << std::setprecision(1) << std::fixed << any_match_percentage
              << "% matched, " << correct_match_percentage << "% correctly, "
              << wrong_match_percentage << "% wrongly, "
              << false_positive_percentage << "% false positives ";
  }
}

}  // namespace tensorflow

tensorflow/examples/speech_commands/accuracy_utils.h (new file, 60 lines)
@@ -0,0 +1,60 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

struct StreamingAccuracyStats {
  StreamingAccuracyStats()
      : how_many_ground_truth_words(0),
        how_many_ground_truth_matched(0),
        how_many_false_positives(0),
        how_many_correct_words(0),
        how_many_wrong_words(0) {}
  int32 how_many_ground_truth_words;
  int32 how_many_ground_truth_matched;
  int32 how_many_false_positives;
  int32 how_many_correct_words;
  int32 how_many_wrong_words;
};

// Takes a file name, and loads a list of expected word labels and times from
// it, as comma-separated variables.
Status ReadGroundTruthFile(const string& file_name,
                           std::vector<std::pair<string, int64>>* result);

// Given ground truth labels and corresponding predictions found by a model,
// figure out how many were correct. Takes a time limit, so that only
// predictions up to a point in time are considered, in case we're evaluating
// accuracy when the model has only been run on part of the stream.
void CalculateAccuracyStats(
    const std::vector<std::pair<string, int64>>& ground_truth_list,
    const std::vector<std::pair<string, int64>>& found_words,
    int64 up_to_time_ms, int64 time_tolerance_ms,
    StreamingAccuracyStats* stats);

// Writes a human-readable description of the statistics to stdout.
void PrintAccuracyStats(const StreamingAccuracyStats& stats);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
59
tensorflow/examples/speech_commands/accuracy_utils_test.cc
Normal file
59
tensorflow/examples/speech_commands/accuracy_utils_test.cc
Normal file
@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/accuracy_utils.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(AccuracyUtilsTest, ReadGroundTruthFile) {
  string file_name = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(),
                                              "ground_truth.txt");
  string file_data = "a,10\nb,12\n";
  TF_ASSERT_OK(WriteStringToFile(Env::Default(), file_name, file_data));

  std::vector<std::pair<string, int64>> ground_truth;
  TF_ASSERT_OK(ReadGroundTruthFile(file_name, &ground_truth));
  ASSERT_EQ(2, ground_truth.size());
  EXPECT_EQ("a", ground_truth[0].first);
  EXPECT_EQ(10, ground_truth[0].second);
  EXPECT_EQ("b", ground_truth[1].first);
  EXPECT_EQ(12, ground_truth[1].second);
}

TEST(AccuracyUtilsTest, CalculateAccuracyStats) {
  StreamingAccuracyStats stats;
  CalculateAccuracyStats({{"a", 1000}, {"b", 9000}},
                         {{"a", 1200}, {"b", 5000}, {"a", 8700}}, 10000, 500,
                         &stats);
  EXPECT_EQ(2, stats.how_many_ground_truth_words);
  EXPECT_EQ(2, stats.how_many_ground_truth_matched);
  EXPECT_EQ(1, stats.how_many_false_positives);
  EXPECT_EQ(1, stats.how_many_correct_words);
  EXPECT_EQ(1, stats.how_many_wrong_words);
}

TEST(AccuracyUtilsTest, PrintAccuracyStats) {
  StreamingAccuracyStats stats;
  PrintAccuracyStats(stats);
}

}  // namespace tensorflow
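To see where the expected counts in CalculateAccuracyStats come from: with the 500 ms tolerance used in the test, the found 'a' at 1200 ms matches the ground truth 'a' at 1000 ms (the one correct word), the found 'a' at 8700 ms falls within tolerance of the ground truth 'b' at 9000 ms but carries the wrong label (the one wrong word), and the found 'b' at 5000 ms matches nothing (the one false positive), which is why both ground truth entries count as matched.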
167
tensorflow/examples/speech_commands/freeze.py
Normal file
@@ -0,0 +1,167 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.

Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:

bazel run tensorflow/examples/speech_commands:freeze -- \
--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb

One thing to watch out for is that you need to pass in the same arguments for
`sample_rate` and other command line variables here as you did for the training
script.

The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
import tensorflow.examples.speech_commands.input_data as input_data
import tensorflow.examples.speech_commands.models as models
from tensorflow.python.framework import graph_util

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: Length of each audio clip to be analyzed, in milliseconds.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    dct_coefficient_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
  """

  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, dct_coefficient_count)

  wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
  decoded_sample_data = contrib_audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = contrib_audio.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)
  fingerprint_input = contrib_audio.mfcc(
      spectrogram,
      decoded_sample_data.sample_rate,
      dct_coefficient_count=dct_coefficient_count)

  logits = models.create_model(
      fingerprint_input, model_settings, model_architecture, is_training=False)

  # Create an output to use for inference.
  tf.nn.softmax(logits, name='labels_softmax')


def main(_):

  # Create the model and load its weights.
  sess = tf.InteractiveSession()
  create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
                         FLAGS.clip_duration_ms, FLAGS.window_size_ms,
                         FLAGS.window_stride_ms, FLAGS.dct_coefficient_count,
                         FLAGS.model_architecture)
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])
  tf.train.write_graph(
      frozen_graph_def,
      os.path.dirname(FLAGS.output_file),
      os.path.basename(FLAGS.output_file),
      as_text=False)
  tf.logging.info('Saved frozen graph to %s', FLAGS.output_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=20.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices',)
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',)
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file', type=str, help='Where to save the frozen graph.')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
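As a rough sketch of how the resulting file can be consumed (this is an illustration, not part of the example itself; the paths are placeholders), the 'wav_data' input and 'labels_softmax' output named in the docstring can be driven directly from a TensorFlow session:

import tensorflow as tf

def run_frozen_graph(graph_path, wav_path):
  # Load the binary GraphDef written by freeze.py above.
  graph_def = tf.GraphDef()
  with open(graph_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
  with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    with open(wav_path, 'rb') as wav_file:
      wav_data = wav_file.read()
    # Feed the WAV-encoded bytes into 'wav_data' and read back the scores.
    return sess.run('labels_softmax:0', feed_dict={'wav_data:0': wav_data})

# e.g. scores = run_frozen_graph('/tmp/my_frozen_graph.pb', '/tmp/some_clip.wav')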
38
tensorflow/examples/speech_commands/freeze_test.py
Normal file
@@ -0,0 +1,38 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the speech commands graph freezing tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.examples.speech_commands import freeze
from tensorflow.python.platform import test


class FreezeTest(test.TestCase):

  def testCreateInferenceGraph(self):
    with self.test_session() as sess:
      freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 20.0, 10.0, 40,
                                    'conv')
      self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
      self.assertIsNotNone(
          sess.graph.get_tensor_by_name('decoded_sample_data:0'))
      self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))


if __name__ == '__main__':
  test.main()
281
tensorflow/examples/speech_commands/generate_streaming_test_wav.py
Normal file
@@ -0,0 +1,281 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a .wav file with synthesized conversational data and labels.

The best way to estimate the real-world performance of an audio recognition
model is by running it against a continuous stream of data, the way that it
would be used in an application. Training evaluations are only run against
discrete individual samples, so the results aren't as realistic.

To make it easy to run evaluations against audio streams, this script uses
samples from the testing partition of the data set, mixes them in at random
positions together with background noise, and saves out the result as one long
audio file.

Here's an example of generating a test file:

bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav -- \
--data_dir=/tmp/my_wavs --background_dir=/tmp/my_backgrounds \
--background_volume=0.1 --test_duration_seconds=600 \
--output_audio_file=/tmp/streaming_test.wav \
--output_labels_file=/tmp/streaming_test_labels.txt

Once you've created a streaming audio file, you can then use the
test_streaming_accuracy tool to calculate accuracy metrics for a model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import sys

import numpy as np
import tensorflow as tf

import tensorflow.examples.speech_commands.input_data as input_data
import tensorflow.examples.speech_commands.models as models

FLAGS = None


def mix_in_audio_sample(track_data, track_offset, sample_data, sample_offset,
                        clip_duration, sample_volume, ramp_in, ramp_out):
  """Mixes the sample data into the main track at the specified offset.

  Args:
    track_data: Numpy array holding main audio data. Modified in-place.
    track_offset: Where to mix the sample into the main track.
    sample_data: Numpy array of audio data to mix into the main track.
    sample_offset: Where to start in the audio sample.
    clip_duration: How long the sample segment is.
    sample_volume: Loudness to mix the sample in at.
    ramp_in: Length in samples of volume increase stage.
    ramp_out: Length in samples of volume decrease stage.
  """
  ramp_out_index = clip_duration - ramp_out
  track_end = min(track_offset + clip_duration, track_data.shape[0])
  track_end = min(track_end,
                  track_offset + (sample_data.shape[0] - sample_offset))
  sample_range = track_end - track_offset
  for i in range(sample_range):
    if i < ramp_in:
      envelope_scale = i / ramp_in
    elif i > ramp_out_index:
      envelope_scale = (clip_duration - i) / ramp_out
    else:
      envelope_scale = 1
    sample_input = sample_data[sample_offset + i]
    track_data[track_offset
               + i] += sample_input * envelope_scale * sample_volume


def main(_):
  words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      '', FLAGS.data_dir, FLAGS.silence_percentage, 10,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings)

  output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
  output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)

  # Set up background audio.
  background_crossover_ms = 500
  background_segment_duration_ms = (
      FLAGS.clip_duration_ms + background_crossover_ms)
  background_segment_duration_samples = int(
      (background_segment_duration_ms * FLAGS.sample_rate) / 1000)
  background_segment_stride_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  background_ramp_samples = int(
      ((background_crossover_ms / 2) * FLAGS.sample_rate) / 1000)

  # Mix the background audio into the main track.
  how_many_backgrounds = int(
      math.ceil(output_audio_sample_count / background_segment_stride_samples))
  for i in range(how_many_backgrounds):
    output_offset = int(i * background_segment_stride_samples)
    background_index = np.random.randint(len(audio_processor.background_data))
    background_samples = audio_processor.background_data[background_index]
    background_offset = np.random.randint(
        0, len(background_samples) - model_settings['desired_samples'])
    background_volume = np.random.uniform(0, FLAGS.background_volume)
    mix_in_audio_sample(output_audio, output_offset, background_samples,
                        background_offset, background_segment_duration_samples,
                        background_volume, background_ramp_samples,
                        background_ramp_samples)

  # Mix the words into the main track, noting their labels and positions.
  output_labels = []
  word_stride_ms = FLAGS.clip_duration_ms + FLAGS.word_gap_ms
  word_stride_samples = int((word_stride_ms * FLAGS.sample_rate) / 1000)
  clip_duration_samples = int(
      (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
  word_gap_samples = int((FLAGS.word_gap_ms * FLAGS.sample_rate) / 1000)
  how_many_words = int(
      math.floor(output_audio_sample_count / word_stride_samples))
  all_test_data, all_test_labels = audio_processor.get_unprocessed_data(
      -1, model_settings, 'testing')
  for i in range(how_many_words):
    output_offset = (
        int(i * word_stride_samples) + np.random.randint(word_gap_samples))
    output_offset_ms = (output_offset * 1000) / FLAGS.sample_rate
    is_unknown = np.random.randint(100) < FLAGS.unknown_percentage
    if is_unknown:
      wanted_label = input_data.UNKNOWN_WORD_LABEL
    else:
      wanted_label = words_list[2 + np.random.randint(len(words_list) - 2)]
    test_data_start = np.random.randint(len(all_test_data))
    found_sample_data = None
    index_lookup = np.arange(len(all_test_data), dtype=np.int32)
    np.random.shuffle(index_lookup)
    for test_data_offset in range(len(all_test_data)):
      test_data_index = index_lookup[(
          test_data_start + test_data_offset) % len(all_test_data)]
      current_label = all_test_labels[test_data_index]
      if current_label == wanted_label:
        found_sample_data = all_test_data[test_data_index]
        break
    mix_in_audio_sample(output_audio, output_offset, found_sample_data, 0,
                        clip_duration_samples, 1.0, 500, 500)
    output_labels.append({'label': wanted_label, 'time': output_offset_ms})

  input_data.save_wav_file(FLAGS.output_audio_file, output_audio,
                           FLAGS.sample_rate)
  tf.logging.info('Saved streaming test wav to %s', FLAGS.output_audio_file)

  with open(FLAGS.output_labels_file, 'w') as f:
    for output_label in output_labels:
      f.write('%s, %f\n' % (output_label['label'], output_label['time']))
  tf.logging.info('Saved streaming test labels to %s', FLAGS.output_labels_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_dir',
      type=str,
      default='',
      help="""\
      Path to a directory of .wav files to mix in as background noise during training.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs.',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=20.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices',)
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',)
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_audio_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test.wav',
      help='File to save the generated test audio to.')
  parser.add_argument(
      '--output_labels_file',
      type=str,
      default='/tmp/speech_commands_train/streaming_test_labels.txt',
      help='File to save the generated test labels to.')
  parser.add_argument(
      '--test_duration_seconds',
      type=int,
      default=600,
      help='How long the generated test audio file should be.',)
  parser.add_argument(
      '--word_gap_ms',
      type=int,
      default=2000,
      help='How long the average gap should be between words.',)
  parser.add_argument(
      '--unknown_percentage',
      type=int,
      default=30,
      help='What percentage of words should be unknown.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
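The labels file written above holds one 'label, time_in_ms' line per mixed-in word. A small sketch, assuming that format, of reading it back for scoring against a model's output:

def load_streaming_labels(labels_path):
  """Parses the 'label, time_ms' lines written by this script."""
  labels = []
  with open(labels_path) as f:
    for line in f:
      if not line.strip():
        continue
      label, time_ms = line.split(',')
      labels.append((label.strip(), float(time_ms)))
  return labels

# e.g. labels = load_streaming_labels('/tmp/streaming_test_labels.txt')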
39
tensorflow/examples/speech_commands/generate_streaming_test_wav_test.py
Normal file
@@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test file generation for speech commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.examples.speech_commands import generate_streaming_test_wav
from tensorflow.python.platform import test


class GenerateStreamingTestWavTest(test.TestCase):

  def testMixInAudioSample(self):
    track_data = np.zeros([10000])
    sample_data = np.ones([1000])
    generate_streaming_test_wav.mix_in_audio_sample(
        track_data, 2000, sample_data, 0, 1000, 1.0, 100, 100)
    self.assertNear(1.0, track_data[2500], 0.0001)
    self.assertNear(0.0, track_data[3500], 0.0001)


if __name__ == "__main__":
  test.main()
532
tensorflow/examples/speech_commands/input_data.py
Normal file
@@ -0,0 +1,532 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
"""Loads and prepares audio training data for simple speech recognition."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
import os.path
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from six.moves import urllib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
MAX_NUM_WAVS_PER_CLASS = 2**27 - 1 # ~134M
|
||||
SILENCE_LABEL = '_silence_'
|
||||
SILENCE_INDEX = 0
|
||||
UNKNOWN_WORD_LABEL = '_unknown_'
|
||||
UNKNOWN_WORD_INDEX = 1
|
||||
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
|
||||
RANDOM_SEED = 59185
|
||||
|
||||
|
||||
def prepare_words_list(wanted_words):
|
||||
"""Prepends common tokens to the custom word list.
|
||||
|
||||
Args:
|
||||
wanted_words: List of strings containing the custom words.
|
||||
|
||||
Returns:
|
||||
List with the standard silence and unknown tokens added.
|
||||
"""
|
||||
return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
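As a quick illustration of the token layout this produces (the import path matches this example's package):

from tensorflow.examples.speech_commands import input_data

words = input_data.prepare_words_list(['yes', 'no'])
# words == ['_silence_', '_unknown_', 'yes', 'no']: the reserved tokens occupy
# indices 0 and 1, so the first wanted word lands at index 2, matching the
# 'index + 2' offset used later in prepare_data_index.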
|
||||
|
||||
|
||||
def which_set(filename, validation_percentage, testing_percentage):
|
||||
"""Determines which data partition the file should belong to.
|
||||
|
||||
We want to keep files in the same training, validation, or testing sets even
|
||||
if new ones are added over time. This makes it less likely that testing
|
||||
samples will accidentally be reused in training when long runs are restarted
|
||||
for example. To keep this stability, a hash of the filename is taken and used
|
||||
to determine which set it should belong to. This determination only depends on
|
||||
the name and the set proportions, so it won't change as other files are added.
|
||||
|
||||
It's also useful to associate particular files as related (for example words
|
||||
spoken by the same person), so anything after '_nohash_' in a filename is
|
||||
ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
|
||||
'bobby_nohash_1.wav' are always in the same set, for example.
|
||||
|
||||
Args:
|
||||
filename: File path of the data sample.
|
||||
validation_percentage: How much of the data set to use for validation.
|
||||
testing_percentage: How much of the data set to use for testing.
|
||||
|
||||
Returns:
|
||||
String, one of 'training', 'validation', or 'testing'.
|
||||
"""
|
||||
base_name = os.path.basename(filename)
|
||||
# We want to ignore anything after '_nohash_' in the file name when
|
||||
# deciding which set to put a wav in, so the data set creator has a way of
|
||||
# grouping wavs that are close variations of each other.
|
||||
hash_name = re.sub(r'_nohash_.*$', '', base_name)
|
||||
# This looks a bit magical, but we need to decide whether this file should
|
||||
# go into the training, testing, or validation sets, and we want to keep
|
||||
# existing files in the same set even if more files are subsequently
|
||||
# added.
|
||||
# To do that, we need a stable way of deciding based on just the file name
|
||||
# itself, so we do a hash of that and then use that to generate a
|
||||
# probability value that we use to assign it.
|
||||
hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
|
||||
percentage_hash = ((int(hash_name_hashed, 16) %
|
||||
(MAX_NUM_WAVS_PER_CLASS + 1)) *
|
||||
(100.0 / MAX_NUM_WAVS_PER_CLASS))
|
||||
if percentage_hash < validation_percentage:
|
||||
result = 'validation'
|
||||
elif percentage_hash < (testing_percentage + validation_percentage):
|
||||
result = 'testing'
|
||||
else:
|
||||
result = 'training'
|
||||
return result
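A small sketch of the stability property described above, mirroring the check in input_data_test.py: names that differ only after '_nohash_' always land in the same partition.

from tensorflow.examples.speech_commands import input_data

# Both files hash to the same partition regardless of the numeric suffix.
assert (input_data.which_set('bobby_nohash_0.wav', 10, 10) ==
        input_data.which_set('bobby_nohash_1.wav', 10, 10))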
|
||||
|
||||
|
||||
def load_wav_file(filename):
|
||||
"""Loads an audio file and returns a float PCM-encoded array of samples.
|
||||
|
||||
Args:
|
||||
filename: Path to the .wav file to load.
|
||||
|
||||
Returns:
|
||||
Numpy array holding the sample data as floats between -1.0 and 1.0.
|
||||
"""
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
wav_filename_placeholder = tf.placeholder(tf.string, [])
|
||||
wav_loader = io_ops.read_file(wav_filename_placeholder)
|
||||
wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
|
||||
return sess.run(
|
||||
wav_decoder,
|
||||
feed_dict={wav_filename_placeholder: filename}).audio.flatten()
|
||||
|
||||
|
||||
def save_wav_file(filename, wav_data, sample_rate):
|
||||
"""Saves audio sample data to a .wav audio file.
|
||||
|
||||
Args:
|
||||
filename: Path to save the file to.
|
||||
wav_data: 2D array of float PCM-encoded audio data.
|
||||
sample_rate: Samples per second to encode in the file.
|
||||
"""
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
wav_filename_placeholder = tf.placeholder(tf.string, [])
|
||||
sample_rate_placeholder = tf.placeholder(tf.int32, [])
|
||||
wav_data_placeholder = tf.placeholder(tf.float32, [None, 1])
|
||||
wav_encoder = contrib_audio.encode_wav(wav_data_placeholder,
|
||||
sample_rate_placeholder)
|
||||
wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
|
||||
sess.run(
|
||||
wav_saver,
|
||||
feed_dict={
|
||||
wav_filename_placeholder: filename,
|
||||
sample_rate_placeholder: sample_rate,
|
||||
wav_data_placeholder: np.reshape(wav_data, (-1, 1))
|
||||
})
|
||||
|
||||
|
||||
class AudioProcessor(object):
|
||||
"""Handles loading, partitioning, and preparing audio training data."""
|
||||
|
||||
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
|
||||
wanted_words, validation_percentage, testing_percentage,
|
||||
model_settings):
|
||||
self.data_dir = data_dir
|
||||
self.maybe_download_and_extract_dataset(data_url, data_dir)
|
||||
self.prepare_data_index(silence_percentage, unknown_percentage,
|
||||
wanted_words, validation_percentage,
|
||||
testing_percentage)
|
||||
self.prepare_background_data()
|
||||
self.prepare_processing_graph(model_settings)
|
||||
|
||||
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
|
||||
"""Download and extract data set tar file.
|
||||
|
||||
If the data set we're using doesn't already exist, this function
|
||||
downloads it from the TensorFlow.org website and unpacks it into a
|
||||
directory.
|
||||
If the data_url is none, don't download anything and expect the data
|
||||
directory to contain the correct files already.
|
||||
|
||||
Args:
|
||||
data_url: Web location of the tar file containing the data set.
|
||||
dest_directory: File path to extract data to.
|
||||
"""
|
||||
if not data_url:
|
||||
return
|
||||
if not os.path.exists(dest_directory):
|
||||
os.makedirs(dest_directory)
|
||||
filename = data_url.split('/')[-1]
|
||||
filepath = os.path.join(dest_directory, filename)
|
||||
if not os.path.exists(filepath):
|
||||
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write(
|
||||
'\r>> Downloading %s %.1f%%' %
|
||||
(filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
|
||||
try:
|
||||
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
|
||||
except Exception:
|
||||
tf.logging.error('Failed to download URL: %s to folder: %s', data_url,
|
||||
filepath)
|
||||
tf.logging.error('Please make sure you have enough free space and'
|
||||
' an internet connection')
|
||||
raise
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
tf.logging.info('Successfully downloaded %s (%d bytes)', filename,
|
||||
statinfo.st_size)
|
||||
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
|
||||
|
||||
def prepare_data_index(self, silence_percentage, unknown_percentage,
|
||||
wanted_words, validation_percentage,
|
||||
testing_percentage):
|
||||
"""Prepares a list of the samples organized by set and label.
|
||||
|
||||
The training loop needs a list of all the available data, organized by
|
||||
which partition it should belong to, and with ground truth labels attached.
|
||||
This function analyzes the folders below the `data_dir`, figures out the
|
||||
right
|
||||
labels for each file based on the name of the subdirectory it belongs to,
|
||||
and uses a stable hash to assign it to a data set partition.
|
||||
|
||||
Args:
|
||||
silence_percentage: How much of the resulting data should be background.
|
||||
unknown_percentage: How much should be audio outside the wanted classes.
|
||||
wanted_words: Labels of the classes we want to be able to recognize.
|
||||
validation_percentage: How much of the data set to use for validation.
|
||||
testing_percentage: How much of the data set to use for testing.
|
||||
|
||||
Returns:
|
||||
Dictionary containing a list of file information for each set partition,
|
||||
and a lookup map for each class to determine its numeric index.
|
||||
|
||||
Raises:
|
||||
Exception: If expected files are not found.
|
||||
"""
|
||||
# Make sure the shuffling and picking of unknowns is deterministic.
|
||||
random.seed(RANDOM_SEED)
|
||||
wanted_words_index = {}
|
||||
for index, wanted_word in enumerate(wanted_words):
|
||||
wanted_words_index[wanted_word] = index + 2
|
||||
self.data_index = {'validation': [], 'testing': [], 'training': []}
|
||||
unknown_index = {'validation': [], 'testing': [], 'training': []}
|
||||
all_words = {}
|
||||
# Look through all the subfolders to find audio samples
|
||||
search_path = os.path.join(self.data_dir, '*', '*.wav')
|
||||
for wav_path in gfile.Glob(search_path):
|
||||
word = re.search('.*/([^/]+)/.*.wav', wav_path).group(1).lower()
|
||||
# Treat the '_background_noise_' folder as a special case, since we expect
|
||||
# it to contain long audio samples we mix in to improve training.
|
||||
if word == BACKGROUND_NOISE_DIR_NAME:
|
||||
continue
|
||||
all_words[word] = True
|
||||
set_index = which_set(wav_path, validation_percentage, testing_percentage)
|
||||
# If it's a known class, store its detail, otherwise add it to the list
|
||||
# we'll use to train the unknown label.
|
||||
if word in wanted_words_index:
|
||||
self.data_index[set_index].append({'label': word, 'file': wav_path})
|
||||
else:
|
||||
unknown_index[set_index].append({'label': word, 'file': wav_path})
|
||||
if not all_words:
|
||||
raise Exception('No .wavs found at ' + search_path)
|
||||
for index, wanted_word in enumerate(wanted_words):
|
||||
if wanted_word not in all_words:
|
||||
raise Exception('Expected to find ' + wanted_word +
|
||||
' in labels but only found ' +
|
||||
', '.join(all_words.keys()))
|
||||
# We need an arbitrary file to load as the input for the silence samples.
|
||||
# It's multiplied by zero later, so the content doesn't matter.
|
||||
silence_wav_path = self.data_index['training'][0]['file']
|
||||
for set_index in ['validation', 'testing', 'training']:
|
||||
set_size = len(self.data_index[set_index])
|
||||
silence_size = int(math.ceil(set_size * silence_percentage / 100))
|
||||
for _ in range(silence_size):
|
||||
self.data_index[set_index].append({
|
||||
'label': SILENCE_LABEL,
|
||||
'file': silence_wav_path
|
||||
})
|
||||
# Pick some unknowns to add to each partition of the data set.
|
||||
random.shuffle(unknown_index[set_index])
|
||||
unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
|
||||
self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
|
||||
# Make sure the ordering is random.
|
||||
for set_index in ['validation', 'testing', 'training']:
|
||||
random.shuffle(self.data_index[set_index])
|
||||
# Prepare the rest of the result data structure.
|
||||
self.words_list = prepare_words_list(wanted_words)
|
||||
self.word_to_index = {}
|
||||
for word in all_words:
|
||||
if word in wanted_words_index:
|
||||
self.word_to_index[word] = wanted_words_index[word]
|
||||
else:
|
||||
self.word_to_index[word] = UNKNOWN_WORD_INDEX
|
||||
self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
|
||||
|
||||
def prepare_background_data(self):
|
||||
"""Searches a folder for background noise audio, and loads it into memory.
|
||||
|
||||
It's expected that the background audio samples will be in a subdirectory
|
||||
named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
|
||||
the sample rate of the training data, but can be much longer in duration.
|
||||
|
||||
If the '_background_noise_' folder doesn't exist at all, this isn't an
|
||||
error, it's just taken to mean that no background noise augmentation should
|
||||
be used. If the folder does exist, but it's empty, that's treated as an
|
||||
error.
|
||||
|
||||
Returns:
|
||||
List of raw PCM-encoded audio samples of background noise.
|
||||
|
||||
Raises:
|
||||
Exception: If files aren't found in the folder.
|
||||
"""
|
||||
self.background_data = []
|
||||
background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
|
||||
if not os.path.exists(background_dir):
|
||||
return self.background_data
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
wav_filename_placeholder = tf.placeholder(tf.string, [])
|
||||
wav_loader = io_ops.read_file(wav_filename_placeholder)
|
||||
wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
|
||||
search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
|
||||
'*.wav')
|
||||
for wav_path in gfile.Glob(search_path):
|
||||
wav_data = sess.run(
|
||||
wav_decoder,
|
||||
feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
|
||||
self.background_data.append(wav_data)
|
||||
if not self.background_data:
|
||||
raise Exception('No background wav files were found in ' + search_path)
|
||||
|
||||
def prepare_processing_graph(self, model_settings):
|
||||
"""Builds a TensorFlow graph to apply the input distortions.
|
||||
|
||||
Creates a graph that loads a WAVE file, decodes it, scales the volume,
|
||||
shifts it in time, adds in background noise, calculates a spectrogram, and
|
||||
then builds an MFCC fingerprint from that.
|
||||
|
||||
This must be called with an active TensorFlow session running, and it
|
||||
creates multiple placeholder inputs, and one output:
|
||||
|
||||
- wav_filename_placeholder_: Filename of the WAV to load.
|
||||
- foreground_volume_placeholder_: How loud the main clip should be.
|
||||
- time_shift_padding_placeholder_: Where to pad the clip.
|
||||
- time_shift_offset_placeholder_: How much to move the clip in time.
|
||||
- background_data_placeholder_: PCM sample data for background noise.
|
||||
- background_volume_placeholder_: Loudness of mixed-in background.
|
||||
- mfcc_: Output 2D fingerprint of processed audio.
|
||||
|
||||
Args:
|
||||
model_settings: Information about the current model being trained.
|
||||
"""
|
||||
desired_samples = model_settings['desired_samples']
|
||||
self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
|
||||
wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
|
||||
wav_decoder = contrib_audio.decode_wav(
|
||||
wav_loader, desired_channels=1, desired_samples=desired_samples)
|
||||
# Allow the audio sample's volume to be adjusted.
|
||||
self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
|
||||
scaled_foreground = tf.multiply(wav_decoder.audio,
|
||||
self.foreground_volume_placeholder_)
|
||||
# Shift the sample's start position, and pad any gaps with zeros.
|
||||
self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
|
||||
self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
|
||||
padded_foreground = tf.pad(
|
||||
scaled_foreground,
|
||||
self.time_shift_padding_placeholder_,
|
||||
mode='CONSTANT')
|
||||
sliced_foreground = tf.slice(padded_foreground,
|
||||
self.time_shift_offset_placeholder_,
|
||||
[desired_samples, -1])
|
||||
# Mix in background noise.
|
||||
self.background_data_placeholder_ = tf.placeholder(tf.float32,
|
||||
[desired_samples, 1])
|
||||
self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
|
||||
background_mul = tf.multiply(self.background_data_placeholder_,
|
||||
self.background_volume_placeholder_)
|
||||
background_add = tf.add(background_mul, sliced_foreground)
|
||||
background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
|
||||
# Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
|
||||
spectrogram = contrib_audio.audio_spectrogram(
|
||||
background_clamp,
|
||||
window_size=model_settings['window_size_samples'],
|
||||
stride=model_settings['window_stride_samples'],
|
||||
magnitude_squared=True)
|
||||
self.mfcc_ = contrib_audio.mfcc(
|
||||
spectrogram,
|
||||
wav_decoder.sample_rate,
|
||||
dct_coefficient_count=model_settings['dct_coefficient_count'])
|
||||
|
||||
def set_size(self, mode):
|
||||
"""Calculates the number of samples in the dataset partition.
|
||||
|
||||
Args:
|
||||
mode: Which partition, must be 'training', 'validation', or 'testing'.
|
||||
|
||||
Returns:
|
||||
Number of samples in the partition.
|
||||
"""
|
||||
return len(self.data_index[mode])
|
||||
|
||||
def get_data(self, how_many, offset, model_settings, background_frequency,
|
||||
background_volume_range, time_shift, mode, sess):
|
||||
"""Gather samples from the data set, applying transformations as needed.
|
||||
|
||||
When the mode is 'training', a random selection of samples will be returned,
|
||||
otherwise the first N clips in the partition will be used. This ensures that
|
||||
validation always uses the same samples, reducing noise in the metrics.
|
||||
|
||||
Args:
|
||||
how_many: Desired number of samples to return. -1 means the entire
|
||||
contents of this partition.
|
||||
offset: Where to start when fetching deterministically.
|
||||
model_settings: Information about the current model being trained.
|
||||
background_frequency: How many clips will have background noise, 0.0 to
|
||||
1.0.
|
||||
background_volume_range: How loud the background noise will be.
|
||||
time_shift: How much to randomly shift the clips by in time.
|
||||
mode: Which partition to use, must be 'training', 'validation', or
|
||||
'testing'.
|
||||
sess: TensorFlow session that was active when processor was created.
|
||||
|
||||
Returns:
|
||||
List of sample data for the transformed samples, and list of labels in
|
||||
one-hot form.
|
||||
"""
|
||||
# Pick one of the partitions to choose samples from.
|
||||
candidates = self.data_index[mode]
|
||||
if how_many == -1:
|
||||
sample_count = len(candidates)
|
||||
else:
|
||||
sample_count = max(0, min(how_many, len(candidates) - offset))
|
||||
# Data and labels will be populated and returned.
|
||||
data = np.zeros((sample_count, model_settings['fingerprint_size']))
|
||||
labels = np.zeros((sample_count, model_settings['label_count']))
|
||||
desired_samples = model_settings['desired_samples']
|
||||
use_background = self.background_data and (mode == 'training')
|
||||
pick_deterministically = (mode != 'training')
|
||||
# Use the processing graph we created earlier to repeatedly to generate the
|
||||
# final output sample data we'll use in training.
|
||||
for i in xrange(offset, offset + sample_count):
|
||||
# Pick which audio sample to use.
|
||||
if how_many == -1 or pick_deterministically:
|
||||
sample_index = i
|
||||
else:
|
||||
sample_index = np.random.randint(len(candidates))
|
||||
sample = candidates[sample_index]
|
||||
# If we're time shifting, set up the offset for this sample.
|
||||
if time_shift > 0:
|
||||
time_shift_amount = np.random.randint(-time_shift, time_shift)
|
||||
else:
|
||||
time_shift_amount = 0
|
||||
if time_shift_amount > 0:
|
||||
time_shift_padding = [[time_shift_amount, 0], [0, 0]]
|
||||
time_shift_offset = [0, 0]
|
||||
else:
|
||||
time_shift_padding = [[0, -time_shift_amount], [0, 0]]
|
||||
time_shift_offset = [-time_shift_amount, 0]
|
||||
input_dict = {
|
||||
self.wav_filename_placeholder_: sample['file'],
|
||||
self.time_shift_padding_placeholder_: time_shift_padding,
|
||||
self.time_shift_offset_placeholder_: time_shift_offset,
|
||||
}
|
||||
# Choose a section of background noise to mix in.
|
||||
if use_background:
|
||||
background_index = np.random.randint(len(self.background_data))
|
||||
background_samples = self.background_data[background_index]
|
||||
background_offset = np.random.randint(
|
||||
0, len(background_samples) - model_settings['desired_samples'])
|
||||
background_clipped = background_samples[background_offset:(
|
||||
background_offset + desired_samples)]
|
||||
background_reshaped = background_clipped.reshape([desired_samples, 1])
|
||||
if np.random.uniform(0, 1) < background_frequency:
|
||||
background_volume = np.random.uniform(0, background_volume_range)
|
||||
else:
|
||||
background_volume = 0
|
||||
else:
|
||||
background_reshaped = np.zeros([desired_samples, 1])
|
||||
background_volume = 0
|
||||
input_dict[self.background_data_placeholder_] = background_reshaped
|
||||
input_dict[self.background_volume_placeholder_] = background_volume
|
||||
# If we want silence, mute out the main sample but leave the background.
|
||||
if sample['label'] == SILENCE_LABEL:
|
||||
input_dict[self.foreground_volume_placeholder_] = 0
|
||||
else:
|
||||
input_dict[self.foreground_volume_placeholder_] = 1
|
||||
# Run the graph to produce the output audio.
|
||||
data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
|
||||
label_index = self.word_to_index[sample['label']]
|
||||
labels[i - offset, label_index] = 1
|
||||
return data, labels
|
||||
|
||||
def get_unprocessed_data(self, how_many, model_settings, mode):
|
||||
"""Retrieve sample data for the given partition, with no transformations.
|
||||
|
||||
Args:
|
||||
how_many: Desired number of samples to return. -1 means the entire
|
||||
contents of this partition.
|
||||
model_settings: Information about the current model being trained.
|
||||
mode: Which partition to use, must be 'training', 'validation', or
|
||||
'testing'.
|
||||
|
||||
Returns:
|
||||
List of sample data for the samples, and list of labels in one-hot form.
|
||||
"""
|
||||
candidates = self.data_index[mode]
|
||||
if how_many == -1:
|
||||
sample_count = len(candidates)
|
||||
else:
|
||||
sample_count = how_many
|
||||
desired_samples = model_settings['desired_samples']
|
||||
words_list = self.words_list
|
||||
data = np.zeros((sample_count, desired_samples))
|
||||
labels = []
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
wav_filename_placeholder = tf.placeholder(tf.string, [])
|
||||
wav_loader = io_ops.read_file(wav_filename_placeholder)
|
||||
wav_decoder = contrib_audio.decode_wav(
|
||||
wav_loader, desired_channels=1, desired_samples=desired_samples)
|
||||
foreground_volume_placeholder = tf.placeholder(tf.float32, [])
|
||||
scaled_foreground = tf.multiply(wav_decoder.audio,
|
||||
foreground_volume_placeholder)
|
||||
for i in range(sample_count):
|
||||
if how_many == -1:
|
||||
sample_index = i
|
||||
else:
|
||||
sample_index = np.random.randint(len(candidates))
|
||||
sample = candidates[sample_index]
|
||||
input_dict = {wav_filename_placeholder: sample['file']}
|
||||
if sample['label'] == SILENCE_LABEL:
|
||||
input_dict[foreground_volume_placeholder] = 0
|
||||
else:
|
||||
input_dict[foreground_volume_placeholder] = 1
|
||||
data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
|
||||
label_index = self.word_to_index[sample['label']]
|
||||
labels.append(words_list[label_index])
|
||||
return data, labels
|
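A minimal sketch of driving AudioProcessor for one training batch, reusing the toy model settings from the unit tests below; the data directory is a placeholder and must already contain per-word subfolders of .wav files:

import tensorflow as tf

from tensorflow.examples.speech_commands import input_data

model_settings = {
    'desired_samples': 160,
    'fingerprint_size': 40,
    'label_count': 4,
    'window_size_samples': 100,
    'window_stride_samples': 100,
    'dct_coefficient_count': 40,
}
audio_processor = input_data.AudioProcessor(
    '', '/tmp/my_wavs', 10, 10, ['yes', 'no'], 10, 10, model_settings)
with tf.Session() as sess:
  # 100 randomly chosen training samples, 80% of them with background noise
  # mixed in at up to 10% volume, shifted by up to 100 samples in time.
  fingerprints, one_hot_labels = audio_processor.get_data(
      100, 0, model_settings, 0.8, 0.1, 100, 'training', sess)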
212
tensorflow/examples/speech_commands/input_data_test.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for data input for speech commands."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
|
||||
from tensorflow.examples.speech_commands import input_data
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InputDataTest(test.TestCase):
|
||||
|
||||
def _getWavData(self):
|
||||
with self.test_session() as sess:
|
||||
sample_data = tf.zeros([1000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
with open(filename, "wb") as f:
|
||||
f.write(wav_data)
|
||||
|
||||
def _saveWavFolders(self, root_dir, labels, how_many):
|
||||
wav_data = self._getWavData()
|
||||
for label in labels:
|
||||
dir_name = os.path.join(root_dir, label)
|
||||
os.mkdir(dir_name)
|
||||
for i in range(how_many):
|
||||
file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
|
||||
self._saveTestWavFile(file_path, wav_data)
|
||||
|
||||
def _model_settings(self):
|
||||
return {
|
||||
"desired_samples": 160,
|
||||
"fingerprint_size": 40,
|
||||
"label_count": 4,
|
||||
"window_size_samples": 100,
|
||||
"window_stride_samples": 100,
|
||||
"dct_coefficient_count": 40,
|
||||
}
|
||||
|
||||
def testPrepareWordsList(self):
|
||||
words_list = ["a", "b"]
|
||||
self.assertGreater(
|
||||
len(input_data.prepare_words_list(words_list)), len(words_list))
|
||||
|
||||
def testWhichSet(self):
|
||||
self.assertEqual(
|
||||
input_data.which_set("foo.wav", 10, 10),
|
||||
input_data.which_set("foo.wav", 10, 10))
|
||||
self.assertEqual(
|
||||
input_data.which_set("foo_nohash_0.wav", 10, 10),
|
||||
input_data.which_set("foo_nohash_1.wav", 10, 10))
|
||||
|
||||
def testPrepareDataIndex(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
|
||||
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
|
||||
10, 10, self._model_settings())
|
||||
self.assertLess(0, audio_processor.set_size("training"))
|
||||
self.assertTrue("training" in audio_processor.data_index)
|
||||
self.assertTrue("validation" in audio_processor.data_index)
|
||||
self.assertTrue("testing" in audio_processor.data_index)
|
||||
self.assertEqual(input_data.UNKNOWN_WORD_INDEX,
|
||||
audio_processor.word_to_index["c"])
|
||||
|
||||
def testPrepareDataIndexEmpty(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
|
||||
with self.assertRaises(Exception) as e:
|
||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
|
||||
self._model_settings())
|
||||
self.assertTrue("No .wavs found" in str(e.exception))
|
||||
|
||||
def testPrepareDataIndexMissing(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
|
||||
with self.assertRaises(Exception) as e:
|
||||
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
|
||||
10, self._model_settings())
|
||||
self.assertTrue("Expected to find" in str(e.exception))
|
||||
|
||||
def testPrepareBackgroundData(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
background_dir = os.path.join(tmp_dir, "_background_noise_")
|
||||
os.mkdir(background_dir)
|
||||
wav_data = self._getWavData()
|
||||
for i in range(10):
|
||||
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
|
||||
self._saveTestWavFile(file_path, wav_data)
|
||||
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
|
||||
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
|
||||
10, 10, self._model_settings())
|
||||
self.assertEqual(10, len(audio_processor.background_data))
|
||||
|
||||
def testLoadWavFile(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
file_path = os.path.join(tmp_dir, "load_test.wav")
|
||||
wav_data = self._getWavData()
|
||||
self._saveTestWavFile(file_path, wav_data)
|
||||
sample_data = input_data.load_wav_file(file_path)
|
||||
self.assertIsNotNone(sample_data)
|
||||
|
||||
def testSaveWavFile(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
file_path = os.path.join(tmp_dir, "load_test.wav")
|
||||
save_data = np.zeros([16000, 1])
|
||||
input_data.save_wav_file(file_path, save_data, 16000)
|
||||
loaded_data = input_data.load_wav_file(file_path)
|
||||
self.assertIsNotNone(loaded_data)
|
||||
self.assertEqual(16000, len(loaded_data))
|
||||
|
||||
def testPrepareProcessingGraph(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||
os.mkdir(wav_dir)
|
||||
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
|
||||
background_dir = os.path.join(wav_dir, "_background_noise_")
|
||||
os.mkdir(background_dir)
|
||||
wav_data = self._getWavData()
|
||||
for i in range(10):
|
||||
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
|
||||
self._saveTestWavFile(file_path, wav_data)
|
||||
model_settings = {
|
||||
"desired_samples": 160,
|
||||
"fingerprint_size": 40,
|
||||
"label_count": 4,
|
||||
"window_size_samples": 100,
|
||||
"window_stride_samples": 100,
|
||||
"dct_coefficient_count": 40,
|
||||
}
|
||||
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
|
||||
10, 10, model_settings)
|
||||
self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.background_data_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
|
||||
self.assertIsNotNone(audio_processor.mfcc_)
|
||||
|
||||
def testGetData(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||
os.mkdir(wav_dir)
|
||||
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
|
||||
background_dir = os.path.join(wav_dir, "_background_noise_")
|
||||
os.mkdir(background_dir)
|
||||
wav_data = self._getWavData()
|
||||
for i in range(10):
|
||||
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
|
||||
self._saveTestWavFile(file_path, wav_data)
|
||||
model_settings = {
|
||||
"desired_samples": 160,
|
||||
"fingerprint_size": 40,
|
||||
"label_count": 4,
|
||||
"window_size_samples": 100,
|
||||
"window_stride_samples": 100,
|
||||
"dct_coefficient_count": 40,
|
||||
}
|
||||
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
|
||||
10, 10, model_settings)
|
||||
with self.test_session() as sess:
|
||||
result_data, result_labels = audio_processor.get_data(
|
||||
10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
|
||||
self.assertEqual(10, len(result_data))
|
||||
self.assertEqual(10, len(result_labels))
|
||||
|
||||
def testGetUnprocessedData(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
wav_dir = os.path.join(tmp_dir, "wavs")
|
||||
os.mkdir(wav_dir)
|
||||
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
|
||||
model_settings = {
|
||||
"desired_samples": 160,
|
||||
"fingerprint_size": 40,
|
||||
"label_count": 4,
|
||||
"window_size_samples": 100,
|
||||
"window_stride_samples": 100,
|
||||
"dct_coefficient_count": 40,
|
||||
}
|
||||
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
|
||||
10, 10, model_settings)
|
||||
result_data, result_labels = audio_processor.get_unprocessed_data(
|
||||
10, model_settings, "training")
|
||||
self.assertEqual(10, len(result_data))
|
||||
self.assertEqual(10, len(result_labels))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
176
tensorflow/examples/speech_commands/label_wav.cc
Normal file
@@ -0,0 +1,176 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
// These are all common classes it's handy to reference with no namespace.
|
||||
using tensorflow::Flag;
|
||||
using tensorflow::Status;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::int32;
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
||||
// Reads a model graph definition from disk, and creates a session object you
|
||||
// can use to run it.
|
||||
Status LoadGraph(const string& graph_file_name,
|
||||
std::unique_ptr<tensorflow::Session>* session) {
|
||||
tensorflow::GraphDef graph_def;
|
||||
Status load_graph_status =
|
||||
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
|
||||
if (!load_graph_status.ok()) {
|
||||
return tensorflow::errors::NotFound("Failed to load compute graph at '",
|
||||
graph_file_name, "'");
|
||||
}
|
||||
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||
Status session_create_status = (*session)->Create(graph_def);
|
||||
if (!session_create_status.ok()) {
|
||||
return session_create_status;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Takes a file name, and loads a list of labels from it, one per line, and
|
||||
// returns a vector of the strings.
|
||||
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
|
||||
std::ifstream file(file_name);
|
||||
if (!file) {
|
||||
return tensorflow::errors::NotFound("Labels file ", file_name,
|
||||
" not found.");
|
||||
}
|
||||
result->clear();
|
||||
string line;
|
||||
while (std::getline(file, line)) {
|
||||
result->push_back(line);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Analyzes the output of the graph to retrieve the highest scores and
|
||||
// their positions in the tensor.
|
||||
void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
|
||||
Tensor* out_indices, Tensor* out_scores) {
|
||||
const Tensor& unsorted_scores_tensor = outputs[0];
|
||||
auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
|
||||
std::vector<std::pair<int, float>> scores;
|
||||
scores.reserve(unsorted_scores_flat.size());
|
||||
for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
|
||||
scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
|
||||
}
|
||||
std::sort(scores.begin(), scores.end(),
|
||||
[](const std::pair<int, float>& left,
|
||||
const std::pair<int, float>& right) {
|
||||
return left.second > right.second;
|
||||
});
|
||||
scores.resize(how_many_labels);
|
||||
Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
|
||||
Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
|
||||
for (int i = 0; i < scores.size(); ++i) {
|
||||
sorted_indices.flat<int>()(i) = scores[i].first;
|
||||
sorted_scores.flat<float>()(i) = scores[i].second;
|
||||
}
|
||||
*out_indices = sorted_indices;
|
||||
*out_scores = sorted_scores;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
string wav = "";
|
||||
string graph = "";
|
||||
string labels = "";
|
||||
string input_name = "wav_data";
|
||||
string output_name = "labels_softmax";
|
||||
int32 how_many_labels = 3;
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("wav", &wav, "audio file to be identified"),
|
||||
Flag("graph", &graph, "model to be executed"),
|
||||
Flag("labels", &labels, "path to file containing labels"),
|
||||
Flag("input_name", &input_name, "name of input node in model"),
|
||||
Flag("output_name", &output_name, "name of output node in model"),
|
||||
Flag("how_many_labels", &how_many_labels, "number of results to show"),
|
||||
};
|
||||
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << usage;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// We need to call this to set up global state for TensorFlow.
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
if (argc > 1) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// First we load and initialize the model.
|
||||
std::unique_ptr<tensorflow::Session> session;
|
||||
Status load_graph_status = LoadGraph(graph, &session);
|
||||
if (!load_graph_status.ok()) {
|
||||
LOG(ERROR) << load_graph_status;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<string> labels_list;
|
||||
Status read_labels_status = ReadLabelsFile(labels, &labels_list);
|
||||
if (!read_labels_status.ok()) {
|
||||
LOG(ERROR) << read_labels_status;
|
||||
return -1;
|
||||
}
|
||||
|
||||
string wav_string;
|
||||
Status read_wav_status = tensorflow::ReadFileToString(
|
||||
tensorflow::Env::Default(), wav, &wav_string);
|
||||
if (!read_wav_status.ok()) {
|
||||
LOG(ERROR) << read_wav_status;
|
||||
return -1;
|
||||
}
|
||||
Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
|
||||
wav_tensor.scalar<string>()() = wav_string;
|
||||
|
||||
// Actually run the audio through the model.
|
||||
std::vector<Tensor> outputs;
|
||||
Status run_status =
|
||||
session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
|
||||
if (!run_status.ok()) {
|
||||
LOG(ERROR) << "Running model failed: " << run_status;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Tensor indices;
|
||||
Tensor scores;
|
||||
GetTopLabels(outputs, how_many_labels, &indices, &scores);
|
||||
tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
|
||||
tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
|
||||
for (int pos = 0; pos < how_many_labels; ++pos) {
|
||||
const int label_index = indices_flat(pos);
|
||||
const float score = scores_flat(pos);
|
||||
LOG(INFO) << labels_list[label_index] << " (" << label_index
|
||||
<< "): " << score;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
133
tensorflow/examples/speech_commands/label_wav.py
Normal file
@@ -0,0 +1,133 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a trained audio graph against a WAVE file and reports the results.

The model, labels and .wav file specified in the arguments will be loaded, and
then the predictions from running the model against the audio data will be
printed to the console. This is a useful script for sanity checking trained
models, and as an example of how to use an audio model from Python.

Here's an example of running it:

python tensorflow/examples/speech_commands/label_wav.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

# pylint: disable=unused-import
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
# pylint: enable=unused-import

FLAGS = None


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.gfile.GFile(filename)]


def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input clip count, and the other has
    #   predictions per class.
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0


def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
  """Loads the model and labels, and runs the inference to print predictions."""
  if not wav or not tf.gfile.Exists(wav):
    tf.logging.fatal('Audio file does not exist %s', wav)

  if not labels or not tf.gfile.Exists(labels):
    tf.logging.fatal('Labels file does not exist %s', labels)

  if not graph or not tf.gfile.Exists(graph):
    tf.logging.fatal('Graph file does not exist %s', graph)

  labels_list = load_labels(labels)

  # load graph, which is stored in the default session
  load_graph(graph)

  with open(wav, 'rb') as wav_file:
    wav_data = wav_file.read()

  run_graph(wav_data, labels_list, input_name, output_name, how_many_labels)


def main(_):
  """Entry point for script, converts flags to arguments."""
  label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
            FLAGS.output_name, FLAGS.how_many_labels)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--wav', type=str, default='', help='Audio file to be identified.')
  parser.add_argument(
      '--graph', type=str, default='', help='Model to use for identification.')
  parser.add_argument(
      '--labels', type=str, default='', help='Path to file containing labels.')
  parser.add_argument(
      '--input_name',
      type=str,
      default='wav_data:0',
      help='Name of WAVE data input node in model.')
  parser.add_argument(
      '--output_name',
      type=str,
      default='labels_softmax:0',
      help='Name of node outputting a prediction in the model.')
  parser.add_argument(
      '--how_many_labels',
      type=int,
      default=3,
      help='Number of results to show.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
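
Besides the command-line entry point above, label_wav.label_wav() can also be called directly from Python, which is how the accompanying label_wav_test.py exercises it. A minimal sketch follows; the graph, labels, and .wav paths are the placeholder paths from the module docstring, so substitute your own:

from tensorflow.examples.speech_commands import label_wav

# Prints the top three predictions for one clip, using the node names the
# frozen training graph exposes ("wav_data:0" in, "labels_softmax:0" out).
label_wav.label_wav('/tmp/speech_dataset/left/a5d485dc_nohash_0.wav',
                    '/tmp/speech_commands_train/conv_labels.txt',
                    '/tmp/my_frozen_graph.pb',
                    'wav_data:0', 'labels_softmax:0', 3)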
64
tensorflow/examples/speech_commands/label_wav_test.py
Normal file
@@ -0,0 +1,64 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for WAVE file labeling tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import label_wav
from tensorflow.python.platform import test


class LabelWavTest(test.TestCase):

  def _getWavData(self):
    with self.test_session() as sess:
      sample_data = tf.zeros([1000, 2])
      wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
      wav_data = sess.run(wav_encoder)
    return wav_data

  def _saveTestWavFile(self, filename, wav_data):
    with open(filename, "wb") as f:
      f.write(wav_data)

  def testLabelWav(self):
    tmp_dir = self.get_temp_dir()
    wav_data = self._getWavData()
    wav_filename = os.path.join(tmp_dir, "wav_file.wav")
    self._saveTestWavFile(wav_filename, wav_data)
    input_name = "test_input"
    output_name = "test_output"
    graph_filename = os.path.join(tmp_dir, "test_graph.pb")
    with tf.Session() as sess:
      tf.placeholder(tf.string, name=input_name)
      tf.zeros([1, 3], name=output_name)
      with open(graph_filename, "wb") as f:
        f.write(sess.graph.as_graph_def().SerializeToString())
    labels_filename = os.path.join(tmp_dir, "test_labels.txt")
    with open(labels_filename, "w") as f:
      f.write("a\nb\nc\n")
    label_wav.label_wav(wav_filename, labels_filename, graph_filename,
                        input_name + ":0", output_name + ":0", 3)


if __name__ == "__main__":
  test.main()
378
tensorflow/examples/speech_commands/models.py
Normal file
@@ -0,0 +1,378 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definitions for simple speech recognition.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms,
                           dct_coefficient_count):
  """Calculates common settings needed for all models.

  Args:
    label_count: How many classes are to be recognized.
    sample_rate: Number of audio samples per second.
    clip_duration_ms: Length of each audio clip to be analyzed.
    window_size_ms: Duration of frequency analysis window.
    window_stride_ms: How far to move in time between frequency windows.
    dct_coefficient_count: Number of frequency bins to use for analysis.

  Returns:
    Dictionary containing common settings.
  """
  desired_samples = int(sample_rate * clip_duration_ms / 1000)
  window_size_samples = int(sample_rate * window_size_ms / 1000)
  window_stride_samples = int(sample_rate * window_stride_ms / 1000)
  length_minus_window = (desired_samples - window_size_samples)
  if length_minus_window < 0:
    spectrogram_length = 0
  else:
    spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
  fingerprint_size = dct_coefficient_count * spectrogram_length
  return {
      'desired_samples': desired_samples,
      'window_size_samples': window_size_samples,
      'window_stride_samples': window_stride_samples,
      'spectrogram_length': spectrogram_length,
      'dct_coefficient_count': dct_coefficient_count,
      'fingerprint_size': fingerprint_size,
      'label_count': label_count,
      'sample_rate': sample_rate,
  }
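
As a concrete illustration of the arithmetic above, here is a small worked sketch using the same values the unit tests below pass in (these are illustrative inputs, not mandated defaults):

from tensorflow.examples.speech_commands import models

# 16 kHz audio, 1000 ms clips, 20 ms windows moved in 10 ms steps, 40 MFCCs.
settings = models.prepare_model_settings(
    label_count=10, sample_rate=16000, clip_duration_ms=1000,
    window_size_ms=20, window_stride_ms=10, dct_coefficient_count=40)
# desired_samples = 16000, window_size_samples = 320, window_stride_samples = 160
# spectrogram_length = 1 + (16000 - 320) // 160 = 99
# fingerprint_size = 40 * 99 = 3960 feature values fed to each model.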


def create_model(fingerprint_input, model_settings, model_architecture,
                 is_training):
  """Builds a model of the requested architecture compatible with the settings.

  There are many possible ways of deriving predictions from a spectrogram
  input, so this function provides an abstract interface for creating different
  kinds of models in a black-box way. You need to pass in a TensorFlow node as
  the 'fingerprint' input, and this should output a batch of 1D features that
  describe the audio. Typically this will be derived from a spectrogram that's
  been run through an MFCC, but in theory it can be any feature vector of the
  size specified in model_settings['fingerprint_size'].

  The function will build the graph it needs in the current TensorFlow graph,
  and return the tensorflow output that will contain the 'logits' input to the
  softmax prediction process. If training flag is on, it will also return a
  placeholder node that can be used to control the dropout amount.

  See the implementations below for the possible model architectures that can
  be requested.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    model_architecture: String specifying which kind of model to create.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
    Exception: If the architecture type isn't recognized.
  """
  if model_architecture == 'single_fc':
    return create_single_fc_model(fingerprint_input, model_settings,
                                  is_training)
  elif model_architecture == 'conv':
    return create_conv_model(fingerprint_input, model_settings, is_training)
  elif model_architecture == 'low_latency_conv':
    return create_low_latency_conv_model(fingerprint_input, model_settings,
                                         is_training)
  else:
    raise Exception('model_architecture argument "' + model_architecture +
                    '" not recognized, should be one of "single_fc", "conv",' +
                    ' or "low_latency_conv"')
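
The training/inference asymmetry in the return value matters to callers: with is_training=True you get a (logits, dropout placeholder) pair, otherwise logits alone. A short sketch of both call patterns, mirroring what models_test.py below does (the placeholder name here is just for illustration):

import tensorflow as tf
from tensorflow.examples.speech_commands import models

settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
fingerprint_input = tf.placeholder(
    tf.float32, [None, settings['fingerprint_size']], name='fingerprint_input')

# Training graph: also returns the dropout placeholder to feed at train time.
logits, dropout_prob = models.create_model(
    fingerprint_input, settings, 'conv', is_training=True)

# Inference graph: logits only, no dropout placeholder is created.
inference_logits = models.create_model(
    fingerprint_input, settings, 'conv', is_training=False)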


def load_variables_from_checkpoint(sess, start_checkpoint):
  """Utility function to centralize checkpoint restoration.

  Args:
    sess: TensorFlow session.
    start_checkpoint: Path to saved checkpoint on disk.
  """
  saver = tf.train.Saver(tf.global_variables())
  saver.restore(sess, start_checkpoint)


def create_single_fc_model(fingerprint_input, model_settings, is_training):
  """Builds a model with a single hidden fully-connected layer.

  This is a very simple model with just one matmul and bias layer. As you'd
  expect, it doesn't produce very accurate results, but it is very fast and
  simple, so it's useful for sanity testing.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [MatMul]<-(weights)
          v
     [BiasAdd]<-(bias)
          v

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  weights = tf.Variable(
      tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
  bias = tf.Variable(tf.zeros([label_count]))
  logits = tf.matmul(fingerprint_input, weights) + bias
  if is_training:
    return logits, dropout_prob
  else:
    return logits


def create_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a standard convolutional model.

  This is roughly the network labeled as 'cnn-trad-fpool3' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
     [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [Conv2D]<-(weights)
          v
     [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [MatMul]<-(weights)
          v
     [BiasAdd]<-(bias)
          v

  This produces fairly good quality results, but can involve a large number of
  weight parameters and computations. For a cheaper alternative from the same
  paper with slightly less accuracy, see 'low_latency_conv' below.

  During training, dropout nodes are introduced after each relu, controlled by
  a placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 20
  first_filter_count = 64
  first_weights = tf.Variable(
      tf.truncated_normal(
          [first_filter_height, first_filter_width, 1, first_filter_count],
          stddev=0.01))
  first_bias = tf.Variable(tf.zeros([first_filter_count]))
  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
                            'SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu
  max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
  second_filter_width = 4
  second_filter_height = 10
  second_filter_count = 64
  second_weights = tf.Variable(
      tf.truncated_normal(
          [
              second_filter_height, second_filter_width, first_filter_count,
              second_filter_count
          ],
          stddev=0.01))
  second_bias = tf.Variable(tf.zeros([second_filter_count]))
  second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
                             'SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, dropout_prob)
  else:
    second_dropout = second_relu
  second_conv_shape = second_dropout.get_shape()
  second_conv_output_width = second_conv_shape[2]
  second_conv_output_height = second_conv_shape[1]
  second_conv_element_count = int(
      second_conv_output_width * second_conv_output_height *
      second_filter_count)
  flattened_second_conv = tf.reshape(second_dropout,
                                     [-1, second_conv_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_conv_element_count, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc


def create_low_latency_conv_model(fingerprint_input, model_settings,
                                  is_training):
  """Builds a convolutional model with low compute requirements.

  This is roughly the network labeled as 'cnn-one-fstride4' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
     [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
     [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
     [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
     [BiasAdd]<-(bias)
          v

  This produces slightly lower quality results than the 'conv' model, but needs
  fewer weight parameters and computations.

  During training, dropout nodes are introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = input_time_size
  first_filter_count = 186
  first_filter_stride_x = 1
  first_filter_stride_y = 4
  first_weights = tf.Variable(
      tf.truncated_normal(
          [first_filter_height, first_filter_width, 1, first_filter_count],
          stddev=0.01))
  first_bias = tf.Variable(tf.zeros([first_filter_count]))
  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
      1, first_filter_stride_y, first_filter_stride_x, 1
  ], 'VALID') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu
  first_conv_output_width = math.floor(
      (input_frequency_size - first_filter_width + first_filter_stride_x) /
      first_filter_stride_x)
  first_conv_output_height = math.floor(
      (input_time_size - first_filter_height + first_filter_stride_y) /
      first_filter_stride_y)
  first_conv_element_count = int(
      first_conv_output_width * first_conv_output_height * first_filter_count)
  flattened_first_conv = tf.reshape(first_dropout,
                                    [-1, first_conv_element_count])
  first_fc_output_channels = 128
  first_fc_weights = tf.Variable(
      tf.truncated_normal(
          [first_conv_element_count, first_fc_output_channels], stddev=0.01))
  first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
  first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 128
  second_fc_weights = tf.Variable(
      tf.truncated_normal(
          [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
  second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_fc_output_channels, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc
86
tensorflow/examples/speech_commands/models_test.py
Normal file
@@ -0,0 +1,86 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.examples.speech_commands import models
from tensorflow.python.platform import test


class ModelsTest(test.TestCase):

  def testPrepareModelSettings(self):
    self.assertIsNotNone(
        models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))

  def testCreateModelConvTraining(self):
    model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
    with self.test_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_prob = models.create_model(fingerprint_input,
                                                 model_settings, "conv", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_prob)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))

  def testCreateModelConvInference(self):
    model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
    with self.test_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits = models.create_model(fingerprint_input, model_settings, "conv",
                                   False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))

  def testCreateModelLowLatencyConvTraining(self):
    model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
    with self.test_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_prob = models.create_model(
          fingerprint_input, model_settings, "low_latency_conv", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_prob)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))

  def testCreateModelFullyConnectedTraining(self):
    model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
    with self.test_session() as sess:
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      logits, dropout_prob = models.create_model(
          fingerprint_input, model_settings, "single_fc", True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_prob)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))

  def testCreateModelBadArchitecture(self):
    model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
    with self.test_session():
      fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
      with self.assertRaises(Exception) as e:
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)
      self.assertTrue("not recognized" in str(e.exception))


if __name__ == "__main__":
  test.main()
127
tensorflow/examples/speech_commands/recognize_commands.cc
Normal file
@@ -0,0 +1,127 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/recognize_commands.h"

namespace tensorflow {

RecognizeCommands::RecognizeCommands(const std::vector<string>& labels,
                                     int32 average_window_duration_ms,
                                     float detection_threshold,
                                     int32 suppression_ms, int32 minimum_count)
    : labels_(labels),
      average_window_duration_ms_(average_window_duration_ms),
      detection_threshold_(detection_threshold),
      suppression_ms_(suppression_ms),
      minimum_count_(minimum_count) {
  labels_count_ = labels.size();
  previous_top_label_ = "_silence_";
  previous_top_label_time_ = std::numeric_limits<int64>::min();
}

Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results,
                                               const int64 current_time_ms,
                                               string* found_command,
                                               float* score,
                                               bool* is_new_command) {
  if (latest_results.NumElements() != labels_count_) {
    return errors::InvalidArgument(
        "The results for recognition should contain ", labels_count_,
        " elements, but there are ", latest_results.NumElements());
  }

  if ((!previous_results_.empty()) &&
      (current_time_ms < previous_results_.front().first)) {
    return errors::InvalidArgument(
        "Results must be fed in increasing time order, but received a "
        "timestamp of ",
        current_time_ms, " that was earlier than the previous one of ",
        previous_results_.front().first);
  }

  // Add the latest results to the end of the queue.
  previous_results_.push_back({current_time_ms, latest_results});

  // Prune any earlier results that are too old for the averaging window.
  const int64 time_limit = current_time_ms - average_window_duration_ms_;
  while (previous_results_.front().first < time_limit) {
    previous_results_.pop_front();
  }

  // If there are too few results, assume the result will be unreliable and
  // bail.
  const int64 how_many_results = previous_results_.size();
  const int64 earliest_time = previous_results_.front().first;
  const int64 samples_duration = current_time_ms - earliest_time;
  if ((how_many_results < minimum_count_) ||
      (samples_duration < (average_window_duration_ms_ / 4))) {
    *found_command = previous_top_label_;
    *score = 0.0f;
    *is_new_command = false;
    return Status::OK();
  }

  // Calculate the average score across all the results in the window.
  std::vector<float> average_scores(labels_count_);
  for (const auto& previous_result : previous_results_) {
    const Tensor& scores_tensor = previous_result.second;
    auto scores_flat = scores_tensor.flat<float>();
    for (int i = 0; i < scores_flat.size(); ++i) {
      average_scores[i] += scores_flat(i) / how_many_results;
    }
  }

  // Sort the averaged results in descending score order.
  std::vector<std::pair<int, float>> sorted_average_scores;
  sorted_average_scores.reserve(labels_count_);
  for (int i = 0; i < labels_count_; ++i) {
    sorted_average_scores.push_back(
        std::pair<int, float>({i, average_scores[i]}));
  }
  std::sort(sorted_average_scores.begin(), sorted_average_scores.end(),
            [](const std::pair<int, float>& left,
               const std::pair<int, float>& right) {
              return left.second > right.second;
            });

  // See if the latest top score is enough to trigger a detection.
  const int current_top_index = sorted_average_scores[0].first;
  const string current_top_label = labels_[current_top_index];
  const float current_top_score = sorted_average_scores[0].second;
  // If we've recently had another label trigger, assume one that occurs too
  // soon afterwards is a bad result.
  int64 time_since_last_top;
  if ((previous_top_label_ == "_silence_") ||
      (previous_top_label_time_ == std::numeric_limits<int64>::min())) {
    time_since_last_top = std::numeric_limits<int64>::max();
  } else {
    time_since_last_top = current_time_ms - previous_top_label_time_;
  }
  if ((current_top_score > detection_threshold_) &&
      (current_top_label != previous_top_label_) &&
      (time_since_last_top > suppression_ms_)) {
    previous_top_label_ = current_top_label;
    previous_top_label_time_ = current_time_ms;
    *is_new_command = true;
  } else {
    *is_new_command = false;
  }
  *found_command = current_top_label;
  *score = current_top_score;

  return Status::OK();
}

}  // namespace tensorflow
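
The smoothing-and-suppression flow above is easy to follow in plain Python. The sketch below is purely illustrative (it is not part of this commit, drops the quarter-window duration check, and stores scores as plain lists), but it mirrors the same steps: keep a window of recent results, average them, and only report a new command when the averaged top score clears the threshold and enough time has passed since the last trigger.

from collections import deque

class SimpleRecognizeCommands(object):
  """Minimal sketch of the averaging/suppression logic in RecognizeCommands."""

  def __init__(self, labels, average_window_ms=1000, detection_threshold=0.2,
               suppression_ms=500, minimum_count=3):
    self.labels = labels
    self.average_window_ms = average_window_ms
    self.detection_threshold = detection_threshold
    self.suppression_ms = suppression_ms
    self.minimum_count = minimum_count
    self.previous_results = deque()  # (time_ms, [score per label]) pairs
    self.previous_top_label = '_silence_'
    self.previous_top_time = float('-inf')

  def process(self, scores, time_ms):
    # Keep only the results that fall inside the averaging window.
    self.previous_results.append((time_ms, scores))
    while self.previous_results[0][0] < time_ms - self.average_window_ms:
      self.previous_results.popleft()
    if len(self.previous_results) < self.minimum_count:
      return self.previous_top_label, 0.0, False
    # Average the scores across the window and pick the best label.
    count = len(self.previous_results)
    averaged = [sum(s[i] for _, s in self.previous_results) / count
                for i in range(len(self.labels))]
    top_index = max(range(len(self.labels)), key=lambda i: averaged[i])
    top_label, top_score = self.labels[top_index], averaged[top_index]
    # Suppress repeats and triggers that come too soon after the last one.
    is_new = (top_score > self.detection_threshold and
              top_label != self.previous_top_label and
              time_ms - self.previous_top_time > self.suppression_ms)
    if is_new:
      self.previous_top_label = top_label
      self.previous_top_time = time_ms
    return top_label, top_score, is_new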
79
tensorflow/examples/speech_commands/recognize_commands.h
Normal file
@@ -0,0 +1,79 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_

#include <deque>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// This class is designed to apply a very primitive decoding model on top of
// the instantaneous results from running an audio recognition model on a
// single window of samples. It applies smoothing over time so that noisy
// individual label scores are averaged, increasing the confidence that
// apparent matches are real.
// To use it, you should create a class object with the configuration you
// want, and then feed results from running a TensorFlow model into the
// processing method. The timestamp for each subsequent call should be
// increasing from the previous, since the class is designed to process a
// stream of data over time.
class RecognizeCommands {
 public:
  // labels should be a list of the strings associated with each one-hot score.
  // The window duration controls the smoothing. Longer durations will give a
  // higher confidence that the results are correct, but may miss some
  // commands. The detection threshold has a similar effect, with high values
  // increasing the precision at the cost of recall. The minimum count controls
  // how many results need to be in the averaging window before it's seen as a
  // reliable average. This prevents erroneous results, for example while the
  // averaging window is still being populated. The suppression argument
  // disables further recognitions for a set time after one has been triggered,
  // which can help reduce spurious recognitions.
  explicit RecognizeCommands(const std::vector<string>& labels,
                             int32 average_window_duration_ms = 1000,
                             float detection_threshold = 0.2,
                             int32 suppression_ms = 500,
                             int32 minimum_count = 3);

  // Call this with the results of running a model on sample data.
  Status ProcessLatestResults(const Tensor& latest_results,
                              const int64 current_time_ms,
                              string* found_command, float* score,
                              bool* is_new_command);

 private:
  // Configuration
  std::vector<string> labels_;
  int32 average_window_duration_ms_;
  float detection_threshold_;
  int32 suppression_ms_;
  int32 minimum_count_;

  // Working variables
  std::deque<std::pair<int64, Tensor>> previous_results_;
  string previous_top_label_;
  int64 labels_count_;
  int64 previous_top_label_time_;
};

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
114
tensorflow/examples/speech_commands/recognize_commands_test.cc
Normal file
@@ -0,0 +1,114 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/recognize_commands.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(RecognizeCommandsTest, Basic) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"});

  Tensor results(DT_FLOAT, {3});
  test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
      results, 0, &found_command, &score, &is_new_command));
}

TEST(RecognizeCommandsTest, FindCommands) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor results(DT_FLOAT, {3});

  test::FillValues<float>(&results, {0.0f, 1.0f, 0.0f});
  bool has_found_new_command = false;
  string new_command;
  for (int i = 0; i < 10; ++i) {
    string found_command;
    float score;
    bool is_new_command;
    int64 current_time_ms = 0 + (i * 100);
    TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
        results, current_time_ms, &found_command, &score, &is_new_command));
    if (is_new_command) {
      EXPECT_FALSE(has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  EXPECT_TRUE(has_found_new_command);
  EXPECT_EQ("a", new_command);

  test::FillValues<float>(&results, {0.0f, 0.0f, 1.0f});
  has_found_new_command = false;
  new_command = "";
  for (int i = 0; i < 10; ++i) {
    string found_command;
    float score;
    bool is_new_command;
    int64 current_time_ms = 1000 + (i * 100);
    TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
        results, current_time_ms, &found_command, &score, &is_new_command));
    if (is_new_command) {
      EXPECT_FALSE(has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  EXPECT_TRUE(has_found_new_command);
  EXPECT_EQ("b", new_command);
}

TEST(RecognizeCommandsTest, BadInputLength) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor bad_results(DT_FLOAT, {2});
  test::FillValues<float>(&bad_results, {1.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  EXPECT_FALSE(recognize_commands
                   .ProcessLatestResults(bad_results, 0, &found_command,
                                         &score, &is_new_command)
                   .ok());
}

TEST(RecognizeCommandsTest, BadInputTimes) {
  RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);

  Tensor results(DT_FLOAT, {3});
  test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});

  string found_command;
  float score;
  bool is_new_command;
  TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
      results, 100, &found_command, &score, &is_new_command));
  EXPECT_FALSE(recognize_commands
                   .ProcessLatestResults(results, 0, &found_command, &score,
                                         &is_new_command)
                   .ok());
}

}  // namespace tensorflow
|
310
tensorflow/examples/speech_commands/test_streaming_accuracy.cc
Normal file
310
tensorflow/examples/speech_commands/test_streaming_accuracy.cc
Normal file
@ -0,0 +1,310 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
|
||||
Tool to create accuracy statistics from running an audio recognition model on a
|
||||
continuous stream of samples.
|
||||
|
||||
This is designed to be an environment for running experiments on new models and
|
||||
settings to understand the effects they will have in a real application. You
|
||||
need to supply it with a long audio file containing sounds you want to recognize
|
||||
and a text file listing the labels of each sound along with the time they occur.
|
||||
With this information, and a frozen model, the tool will process the audio
|
||||
stream, apply the model, and keep track of how many mistakes and successes the
|
||||
model achieved.
|
||||
|
||||
The matched percentage is the number of sounds that were correctly classified,
|
||||
as a percentage of the total number of sounds listed in the ground truth file.
|
||||
A correct classification is when the right label is chosen within a short time
|
||||
of the expected ground truth, where the time tolerance is controlled by the
|
||||
'time_tolerance_ms' command line flag.
|
||||
|
||||
The wrong percentage is how many sounds triggered a detection (the classifier
|
||||
figured out it wasn't silence or background noise), but the detected class was
|
||||
wrong. This is also a percentage of the total number of ground truth sounds.
|
||||
|
||||
The false positive percentage is how many sounds were detected when there was
|
||||
only silence or background noise. This is also expressed as a percentage of the
|
||||
total number of ground truth sounds, though since it can be large it may go
|
||||
above 100%.
|
||||
|
||||
The easiest way to get an audio file and labels to test with is by using the
|
||||
'generate_streaming_test_wav' script. This will synthesize a test file with
|
||||
randomly placed sounds and background noise, and output a text file with the
|
||||
ground truth.
|
||||
|
||||
If you want to test natural data, you need to use a .wav with the same sample
|
||||
rate as your model (often 16,000 samples per second), and note down where the
|
||||
sounds occur in time. Save this information out as a comma-separated text file,
|
||||
where the first column is the label and the second is the time in seconds from
|
||||
the start of the file that it occurs.
|
||||
|
||||
Here's an example of how to run the tool:
|
||||
|
||||
bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
|
||||
--wav=/tmp/streaming_test_bg.wav \
|
||||
--graph=/tmp/conv_frozen.pb \
|
||||
--labels=/tmp/speech_commands_train/conv_labels.txt \
|
||||
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
|
||||
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
|
||||
--suppression_ms=500 --time_tolerance_ms=1500
|
||||
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/wav/wav_io.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
|
||||
#include "tensorflow/examples/speech_commands/recognize_commands.h"
|
||||
|
||||
// These are all common classes it's handy to reference with no namespace.
|
||||
using tensorflow::Flag;
|
||||
using tensorflow::Status;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::int32;
|
||||
using tensorflow::string;
|
||||

namespace {

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  return Status::OK();
}

}  // namespace

int main(int argc, char* argv[]) {
  string wav = "";
  string graph = "";
  string labels = "";
  string ground_truth = "";
  string input_data_name = "decoded_sample_data:0";
  string input_rate_name = "decoded_sample_data:1";
  string output_name = "labels_softmax";
  int32 clip_duration_ms = 1000;
  int32 sample_stride_ms = 30;
  int32 average_window_ms = 500;
  int32 time_tolerance_ms = 750;
  int32 suppression_ms = 1500;
  float detection_threshold = 0.7f;
  bool verbose = false;
  std::vector<Flag> flag_list = {
      Flag("wav", &wav, "audio file to be identified"),
      Flag("graph", &graph, "model to be executed"),
      Flag("labels", &labels, "path to file containing labels"),
      Flag("ground_truth", &ground_truth,
           "path to file containing correct times and labels of words in the "
           "audio as <word>,<timestamp in ms> lines"),
      Flag("input_data_name", &input_data_name,
           "name of input data node in model"),
      Flag("input_rate_name", &input_rate_name,
           "name of input sample rate node in model"),
      Flag("output_name", &output_name, "name of output node in model"),
      Flag("clip_duration_ms", &clip_duration_ms,
           "length of recognition window"),
      Flag("average_window_ms", &average_window_ms,
           "length of window to smooth results over"),
      Flag("time_tolerance_ms", &time_tolerance_ms,
           "maximum gap allowed between a recognition and ground truth"),
      Flag("suppression_ms", &suppression_ms,
           "how long to ignore others for after a recognition"),
      Flag("sample_stride_ms", &sample_stride_ms,
           "how often to run recognition"),
      Flag("detection_threshold", &detection_threshold,
           "what score is required to trigger detection of a word"),
      Flag("verbose", &verbose, "whether to log extra debugging information"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  Status load_graph_status = LoadGraph(graph, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  std::vector<string> labels_list;
  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return -1;
  }

  std::vector<std::pair<string, int64>> ground_truth_list;
  Status read_ground_truth_status =
      tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
  if (!read_ground_truth_status.ok()) {
    LOG(ERROR) << read_ground_truth_status;
    return -1;
  }

  string wav_string;
  Status read_wav_status = tensorflow::ReadFileToString(
      tensorflow::Env::Default(), wav, &wav_string);
  if (!read_wav_status.ok()) {
    LOG(ERROR) << read_wav_status;
    return -1;
  }
  std::vector<float> audio_data;
  uint32 sample_count;
  uint16 channel_count;
  uint32 sample_rate;
  Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
      wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
  if (!decode_wav_status.ok()) {
    LOG(ERROR) << decode_wav_status;
    return -1;
  }
  if (channel_count != 1) {
    LOG(ERROR) << "Only mono .wav files can be used, but input has "
               << channel_count << " channels.";
    return -1;
  }

  const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
  const int64 sample_stride_samples = (sample_stride_ms * sample_rate) / 1000;
  Tensor audio_data_tensor(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({clip_duration_samples, 1}));

  Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
  sample_rate_tensor.scalar<int32>()() = sample_rate;

  tensorflow::RecognizeCommands recognize_commands(
      labels_list, average_window_ms, detection_threshold, suppression_ms);

  std::vector<std::pair<string, int64>> all_found_words;
  tensorflow::StreamingAccuracyStats previous_stats;

  const int64 audio_data_end = (sample_count - clip_duration_samples);
  for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
       audio_data_offset += sample_stride_samples) {
    const float* input_start = &(audio_data[audio_data_offset]);
    const float* input_end = input_start + clip_duration_samples;
    std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());

    // Actually run the audio through the model.
    std::vector<Tensor> outputs;
    Status run_status = session->Run({{input_data_name, audio_data_tensor},
                                      {input_rate_name, sample_rate_tensor}},
                                     {output_name}, {}, &outputs);
    if (!run_status.ok()) {
      LOG(ERROR) << "Running model failed: " << run_status;
      return -1;
    }

    const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
    string found_command;
    float score;
    bool is_new_command;
    Status recognize_status = recognize_commands.ProcessLatestResults(
        outputs[0], current_time_ms, &found_command, &score, &is_new_command);
    if (!recognize_status.ok()) {
      LOG(ERROR) << "Recognition processing failed: " << recognize_status;
      return -1;
    }

    if (is_new_command && (found_command != "_silence_")) {
      all_found_words.push_back({found_command, current_time_ms});
      if (verbose) {
        tensorflow::StreamingAccuracyStats stats;
        tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
                                           current_time_ms, time_tolerance_ms,
                                           &stats);
        int32 false_positive_delta = stats.how_many_false_positives -
                                     previous_stats.how_many_false_positives;
        int32 correct_delta = stats.how_many_correct_words -
                              previous_stats.how_many_correct_words;
        int32 wrong_delta =
            stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
        string recognition_state;
        if (false_positive_delta == 1) {
          recognition_state = " (False Positive)";
        } else if (correct_delta == 1) {
          recognition_state = " (Correct)";
        } else if (wrong_delta == 1) {
          recognition_state = " (Wrong)";
        } else {
          LOG(ERROR) << "Unexpected state in statistics";
        }
        LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
                  << recognition_state;
        previous_stats = stats;
        tensorflow::PrintAccuracyStats(stats);
      }
    }
  }

  tensorflow::StreamingAccuracyStats stats;
  tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
                                     time_tolerance_ms, &stats);
  tensorflow::PrintAccuracyStats(stats);

  return 0;
}
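The --ground_truth file read above is plain text, one <word>,<timestamp in ms> entry per line, and each recognition is matched against it within --time_tolerance_ms. A minimal sketch of producing such a file; the helper name, paths, and timestamps here are illustrative only, not part of the tools in this commit:

# Writes a streaming ground-truth file in the <word>,<timestamp in ms>
# format the accuracy tool expects. Helper name and paths are hypothetical.
def write_ground_truth(path, labelled_times):
  with open(path, 'w') as f:
    for word, time_ms in labelled_times:  # e.g. [('yes', 1200), ('no', 3500)]
      f.write('%s,%d\n' % (word, time_ms))

write_ground_truth('/tmp/ground_truth.txt', [('yes', 1200), ('no', 3500)])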
427
tensorflow/examples/speech_commands/train.py
Normal file
@@ -0,0 +1,427 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple speech recognition to spot a limited number of keywords.
|
||||
|
||||
This is a self-contained example script that will train a very basic audio
|
||||
recognition model in TensorFlow. It can download the necessary training data,
|
||||
and runs with reasonable defaults to train within a few hours even only using a
|
||||
CPU. For more information see http://tensorflow.org/tutorials/audio_recognition.
|
||||
|
||||
It is intended as an introduction to using neural networks for audio
|
||||
recognition, and is not a full speech recognition system. For more advanced
|
||||
speech systems, I recommend looking into Kaldi. This network uses a keyword
|
||||
detection style to spot discrete words from a small vocabulary, consisting of
|
||||
"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".
|
||||
|
||||
To run the training process, use:
|
||||
|
||||
bazel run tensorflow/examples/speech_commands:train
|
||||
|
||||
This will write out checkpoints to /tmp/speech_commands_train/, and will
|
||||
download over 1GB of open source training data, so you'll need enough free space
|
||||
and a good internet connection. The default data is a collection of thousands of
|
||||
one-second .wav files, each containing one spoken word. This data set is
|
||||
collected from https://aiyprojects.withgoogle.com/open_speech_recording, please
|
||||
consider contributing to help improve this and other models!
|
||||
|
||||
As training progresses, it will print out its accuracy metrics, which should
|
||||
rise above 90% by the end. Once it's complete, you can run the freeze script to
|
||||
get a binary GraphDef that you can easily deploy on mobile applications.
|
||||
|
||||
If you want to train on your own data, you'll need to create .wavs with your
|
||||
recordings, all at a consistent length, and then arrange them into subfolders
|
||||
organized by label. For example, here's a possible file structure:
|
||||
|
||||
my_wavs >
|
||||
up >
|
||||
audio_0.wav
|
||||
audio_1.wav
|
||||
down >
|
||||
audio_2.wav
|
||||
audio_3.wav
|
||||
other>
|
||||
audio_4.wav
|
||||
audio_5.wav
|
||||
|
||||
You'll also need to tell the script what labels to look for, using the
|
||||
`--wanted_words` argument. In this case, 'up,down' might be what you want, and
|
||||
the audio in the 'other' folder would be used to train an 'unknown' category.
|
||||
|
||||
To pull this all together, you'd run:
|
||||
|
||||
bazel run tensorflow/examples/speech_commands:train -- \
|
||||
--data_dir=my_wavs --wanted_words=up,down
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def main(_):
  # We want to see all the logging messages for this tutorial.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
      FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')

  logits, dropout_prob = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.placeholder(
      tf.float32, [None, label_count], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=ground_truth_input, logits=logits))
  tf.summary.scalar('cross_entropy', cross_entropy_mean)
  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    learning_rate_input = tf.placeholder(
        tf.float32, [], name='learning_rate_input')
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate_input).minimize(cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  expected_indices = tf.argmax(ground_truth_input, 1)
  correct_prediction = tf.equal(predicted_indices, expected_indices)
  confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.contrib.framework.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)

  saver = tf.train.Saver(tf.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  tf.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries, evaluation_step, cross_entropy_mean, train_step,
            increment_global_step
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_prob: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate_value, train_accuracy * 100,
                     cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100,
                                                           set_size))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset/',
      help="""\
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--background_volume',
      type=float,
      default=0.1,
      help="""\
      How loud the background noise should be, between 0 and 1.
      """)
  parser.add_argument(
      '--background_frequency',
      type=float,
      default=0.8,
      help="""\
      How many of the training samples have background noise mixed in.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--unknown_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be unknown words.
      """)
  parser.add_argument(
      '--time_shift_ms',
      type=float,
      default=100.0,
      help="""\
      Range to randomly shift the training audio by in time.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=20.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices',)
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',)
  parser.add_argument(
      '--how_many_training_steps',
      type=str,
      default='15000,3000',
      help='How many training loops to run',)
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=400,
      help='How often to evaluate the training results.')
  parser.add_argument(
      '--learning_rate',
      type=str,
      default='0.001,0.0001',
      help='How large a learning rate to use when training.')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='How many items to train with at once',)
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/speech_commands_train',
      help='Directory to write event logs and checkpoint.')
  parser.add_argument(
      '--save_step_interval',
      type=int,
      default=100,
      help='Save model checkpoint every save_steps.')
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--check_nans',
      type=bool,
      default=False,
      help='Whether to check for invalid numbers during processing')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
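The comma-separated schedule documented in train.py picks, for each step, the learning rate of the first stage whose cumulative step count has not yet been exceeded. A standalone sketch of that lookup, using the step counts from the example in the script's comment; the helper name is illustrative only:

# Mirrors the staged learning-rate selection in the training loop above.
def rate_for_step(step, training_steps_list, learning_rates_list):
  steps_so_far = 0
  for steps, rate in zip(training_steps_list, learning_rates_list):
    steps_so_far += steps
    if step <= steps_so_far:
      return rate
  return learning_rates_list[-1]  # past the schedule, keep the final rate

# With --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001:
assert rate_for_step(9999, [10000, 3000], [0.001, 0.0001]) == 0.001
assert rate_for_step(10500, [10000, 3000], [0.001, 0.0001]) == 0.0001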
@@ -1133,6 +1133,15 @@ tf_gen_op_wrapper_private_py(
    ],
)

tf_gen_op_wrapper_private_py(
    name = "audio_ops_gen",
    require_shape_functions = True,
    visibility = [
        "//learning/brain/python/ops:__pkg__",
        "//tensorflow/contrib/framework:__pkg__",
    ],
)

tf_gen_op_wrapper_private_py(
    name = "candidate_sampling_ops_gen",
    visibility = ["//learning/brain/python/ops:__pkg__"],