Speech keyword detector tutorial

Adds a basic training script for a simple audio model to our examples.
See third_party/docs_src/tutorials/audio_recognition.md for full documentation

PiperOrigin-RevId: 165025732
Pete Warden 2017-08-11 14:30:43 -07:00 committed by TensorFlower Gardener
parent e9a8d75bc4
commit 0c6fd1703e
28 changed files with 4321 additions and 13 deletions


@@ -370,6 +370,7 @@ filegroup(
"//tensorflow/examples/label_image:all_files",
"//tensorflow/examples/learn:all_files",
"//tensorflow/examples/saved_model:all_files",
"//tensorflow/examples/speech_commands:all_files",
"//tensorflow/examples/tutorials/estimators:all_files",
"//tensorflow/examples/tutorials/mnist:all_files",
"//tensorflow/examples/tutorials/word2vec:all_files",


@@ -28,6 +28,7 @@ tf_custom_op_py_library(
"python/framework/tensor_util.py",
"python/ops/__init__.py",
"python/ops/arg_scope.py",
"python/ops/audio_ops.py",
"python/ops/checkpoint_ops.py",
"python/ops/ops.py",
"python/ops/prettyprint_ops.py",
@@ -50,6 +51,7 @@ tf_custom_op_py_library(
":gen_variable_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:audio_ops_gen",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",


@@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Audio processing and decoding ops.
@@decode_wav
@@encode_wav
@@audio_spectrogram
@@mfcc
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_audio_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__, [])


@@ -118,6 +118,17 @@ Status ReadValue(const string& data, T* value, int* offset) {
return Status::OK();
}
Status ReadString(const string& data, int expected_length, string* value,
int* offset) {
const int new_offset = *offset + expected_length;
if (new_offset > data.size()) {
return errors::InvalidArgument("Data too short when trying to read string");
}
*value = string(data.begin() + *offset, data.begin() + new_offset);
*offset = new_offset;
return Status::OK();
}
} // namespace
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
@@ -254,17 +265,33 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
// Skip over this unused section.
offset += 2;
}
TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset));
uint32 data_size;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &data_size, &offset));
*sample_count = data_size / bytes_per_sample;
const uint32 data_count = *sample_count * *channel_count;
float_values->resize(data_count);
for (int i = 0; i < data_count; ++i) {
int16 single_channel_value = 0;
TF_RETURN_IF_ERROR(
ReadValue<int16>(wav_string, &single_channel_value, &offset));
(*float_values)[i] = Int16SampleToFloat(single_channel_value);
bool was_data_found = false;
while (offset < wav_string.size()) {
string chunk_id;
TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
uint32 chunk_size;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
if (chunk_id == kDataChunkId) {
if (was_data_found) {
return errors::InvalidArgument("More than one data chunk found in WAV");
}
was_data_found = true;
*sample_count = chunk_size / bytes_per_sample;
const uint32 data_count = *sample_count * *channel_count;
float_values->resize(data_count);
for (int i = 0; i < data_count; ++i) {
int16 single_channel_value = 0;
TF_RETURN_IF_ERROR(
ReadValue<int16>(wav_string, &single_channel_value, &offset));
(*float_values)[i] = Int16SampleToFloat(single_channel_value);
}
} else {
offset += chunk_size;
}
}
if (!was_data_found) {
return errors::InvalidArgument("No data chunk found in WAV");
}
return Status::OK();
}


@@ -62,7 +62,7 @@ Status DecodeWavShapeFn(InferenceContext* c) {
Status EncodeWavShapeFn(InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
c->set_output(0, c->Scalar());
return Status::OK();
}
@@ -104,7 +104,7 @@ Status MfccShapeFn(InferenceContext* c) {
ShapeHandle spectrogram;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
int32 dct_coefficient_count;
TF_RETURN_IF_ERROR(


@@ -0,0 +1,551 @@
# How to Train a Simple Audio Recognition Network
This tutorial will show you how to build a basic speech recognition network that
recognizes ten different words. It's important to know that real speech and
audio recognition systems are much more complex, but like MNIST for images, it
should give you a basic understanding of the techniques involved. Once you've
completed this tutorial, you'll have a model that tries to classify a one second
audio clip as either silence, an unknown word, "yes", "no", "up", "down",
"left", "right", "on", "off", "stop", or "go". You'll also be able to take this
model and run it in an Android application.
## Preparation
You should make sure you have TensorFlow installed, and since the script
downloads over 1GB of training data, you'll need a good internet connection and
enough free space on your machine. The training process itself can take several
hours, so make sure you have a machine available for that long.
## Training
To begin the training process, go to the TensorFlow source tree and run:
```bash
python tensorflow/examples/speech_commands/train.py
```
The script will start off by downloading the [Speech Commands
dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tar.gz),
which consists of 65,000 WAVE audio files of people saying thirty different
words. This data was collected by Google and released under a CC BY license, and
you can help improve it by [contributing five minutes of your own
voice](https://aiyprojects.withgoogle.com/open_speech_recording). The archive is
over 1GB, so this part may take a while, but you should see progress logs, and
after it's been downloaded once you won't need to do this step again.
Once the downloading has completed, you'll see logging information that looks
like this:
```
I0730 16:53:44.766740 55030 train.py:176] Training from step: 1
I0730 16:53:47.289078 55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571
```
This shows that the initialization process is done and the training loop has
begun. You'll see that it outputs information for every training step. Here's a
breakdown of what it means:
`Step #1` shows that we're on the first step of the training loop. In this case
there are going to be 18,000 steps in total, so you can look at the step number
to get an idea of how close it is to finishing.
`rate 0.001000` is the learning rate that's controlling the speed of the
network's weight updates. Early on this is a comparatively high number (0.001),
but for later training cycles it will be reduced 10x, to 0.0001.
`accuracy 7.0%` shows how many of the inputs were classified correctly on this
training step. This value will often fluctuate a lot, but should increase on
average as training progresses. The model outputs an array of numbers, one for
each label, and each number is the predicted likelihood of the input being that
class. The predicted label is picked by choosing the entry with the highest
score. The scores are always between zero and one, with higher values
representing more confidence in the result.
`cross entropy 2.611571` is the result of the loss function that we're using to
guide the training process. This is a score that's obtained by comparing the
vector of scores from the current training run to the correct labels, and this
should trend downwards during training.
After a hundred steps, you should see a line like this:
`I0730 16:54:41.813438 55030 train.py:252] Saving to
"/tmp/speech_commands_train/conv.ckpt-100"`
This is saving out the current trained weights to a checkpoint file. If your
training script gets interrupted, you can look for the last saved checkpoint and
then restart the script with
`--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-100` as a command line
argument to start from that point.
## Confusion Matrix
After four hundred steps, this information will be logged:
```
I0730 16:57:38.073667 55030 train.py:243] Confusion Matrix:
[[258 0 0 0 0 0 0 0 0 0 0 0]
[ 7 6 26 94 7 49 1 15 40 2 0 11]
[ 10 1 107 80 13 22 0 13 10 1 0 4]
[ 1 3 16 163 6 48 0 5 10 1 0 17]
[ 15 1 17 114 55 13 0 9 22 5 0 9]
[ 1 1 6 97 3 87 1 12 46 0 0 10]
[ 8 6 86 84 13 24 1 9 9 1 0 6]
[ 9 3 32 112 9 26 1 36 19 0 0 9]
[ 8 2 12 94 9 52 0 6 72 0 0 2]
[ 16 1 39 74 29 42 0 6 37 9 0 3]
[ 15 6 17 71 50 37 0 6 32 2 1 9]
[ 11 1 6 151 5 42 0 8 16 0 0 20]]
```
The first section is a [confusion
matrix](https://www.tensorflow.org/api_docs/python/tf/confusion_matrix). To
understand what it means, you first need to know the labels being used, which in
this case are "_silence_", "_unknown_", "yes", "no", "up", "down", "left",
"right", "on", "off", "stop", and "go". Each column represents a set of samples
that were predicted to be each label, so the first column represents all the
clips that were predicted to be silence, the second all those that were
predicted to be unknown words, the third "yes", and so on.
Each row represents clips by their correct, ground truth labels. The first row
is all the clips that were silence, the second clips that were unknown words,
the third "yes", etc.
This matrix can be more useful than just a single accuracy score because it
gives a good summary of what mistakes the network is making. In this example you
can see that all of the entries in the first row are zero, apart from the
initial one. Because the first row is all the clips that are actually silence,
this means that none of them were mistakenly labeled as words, so we have no
false negatives for silence. This shows the network is already getting pretty
good at distinguishing silence from words.
If we look down the first column though, we see a lot of non-zero values. The
column represents all the clips that were predicted to be silence, so positive
numbers outside of the first cell are errors. This means that some clips of real
spoken words are actually being predicted to be silence, so we do have quite a
few false positives.
A perfect model would produce a confusion matrix where all of the entries were
zero apart from a diagonal line through the center. Spotting deviations from
that pattern can help you figure out how the model is most easily confused, and
once you've identified the problems you can address them by adding more data or
cleaning up categories.
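If you want to calculate a confusion matrix like this for your own model's
predictions, the `tf.confusion_matrix` op linked above can produce one. Here's a
minimal sketch (not part of the training script), using made-up label indices
where 0 is "_silence_", 1 is "_unknown_", 2 is "yes", and so on:
```python
import tensorflow as tf

# Placeholders for a batch of ground truth label indices and model predictions.
ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
predictions = tf.placeholder(tf.int64, [None], name='predictions')
confusion = tf.confusion_matrix(ground_truth, predictions, num_classes=12)

with tf.Session() as sess:
  matrix = sess.run(confusion, feed_dict={
      ground_truth: [0, 2, 3, 2],  # what the clips actually were
      predictions: [0, 2, 2, 3],   # what the model guessed
  })
  print(matrix)  # rows are correct labels, columns are predicted labels
```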
## Validation
After the confusion matrix, you should see a line like this:
`I0730 16:57:38.073777 55030 train.py:245] Step 400: Validation accuracy = 26.3%
(N=3093)`
It's good practice to separate your data set into three categories. The largest
(in this case roughly 80% of the data) is used for training the network, a
smaller set (10% here, known as "validation") is reserved for evaluation of the
accuracy during training, and another set (the last 10%, "testing") is used to
evaluate the accuracy once after the training is complete.
The reason for this split is that there's always a danger that networks will
start memorizing their inputs during training. By keeping the validation set
separate, you can ensure that the model works with data it's never seen before.
The testing set is an additional safeguard to make sure that you haven't just
been tweaking your model in a way that happens to work for both the training and
validation sets, but not a broader range of inputs.
The training script automatically separates the data set into these three
categories, and the logging line above shows the accuracy of the model when run on
the validation set. Ideally, this should stick fairly close to the training
accuracy. If the training accuracy increases but the validation doesn't, that's
a sign that overfitting is occurring, and your model is only learning things
about the training clips, not broader patterns that generalize.
## TensorBoard
A good way to visualize how the training is progressing is to use TensorBoard. By
default, the script saves out events to /tmp/retrain_logs, and you can load
these by running:
`tensorboard --logdir /tmp/retrain_logs`
Then navigate to [http://localhost:6006](http://localhost:6006) in your browser,
and you'll see charts and graphs showing your model's progress.
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/speech_commands_tensorflow.png"/>
</div>
## Training Finished
After a few hours of training (depending on your machine's speed), the script
should have completed all 18,000 steps. It will print out a final confusion
matrix, along with an accuracy score, all run on the testing set. With the
default settings, you should see an accuracy of between 85% and 90%.
Because audio recognition is particularly useful on mobile devices, next we'll
export it to a compact format that's easy to work with on those platforms. To do
that, run this command line:
```
python tensorflow/examples/speech_commands/freeze.py \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \
--output_file=/tmp/my_frozen_graph.pb
```
Once the frozen model has been created, you can test it with the `label_wav.py`
script, like this:
```
python tensorflow/examples/speech_commands/label_wav.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
```
This should print out three labels:
```
left (score = 0.81477)
right (score = 0.14139)
_unknown_ (score = 0.03808)
```
Hopefully "left" is the top score since that's the correct label, but since the
training is random it may not for the first file you try. Experiment with some
of the other .wav files in that same folder to see how well it does.
The scores are between zero and one, and higher values mean the model is more
confident in its prediction.
## How does this Model Work?
The architecture used in this tutorial is based on models described in the paper
[Convolutional Neural Networks for Small-footprint Keyword
Spotting](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
It was chosen because it's comparatively simple, quick to train, and easy to
understand, rather than being state of the art. There are lots of different
approaches to building neural network models to work with audio, including
[recurrent networks](https://svds.com/tensorflow-rnn-tutorial/) or [dilated
(atrous)
convolutions](https://deepmind.com/blog/wavenet-generative-model-raw-audio/).
This tutorial is based on the kind of convolutional network that will feel very
familiar to anyone who's worked with image recognition. That may seem surprising
at first though, since audio is inherently a one-dimensional continuous signal
across time, not a 2D spatial problem.
We solve that issue by defining a window of time we believe our spoken words
should fit into, and converting the audio signal in that window into an image.
This is done by grouping the incoming audio samples into short segments, just a
few milliseconds long, and calculating the strength of the frequencies across a
set of bands. Each set of frequency strengths from a segment is treated as a
vector of numbers, and those vectors are arranged in time order to form a
two-dimensional array. This array of values can then be treated like a
single-channel image, and is known as a
[spectrogram](https://en.wikipedia.org/wiki/Spectrogram). If you want to view
what kind of image an audio sample produces, you can run the `wav_to_spectrogram`
tool:
```
bazel run tensorflow/examples/wav_to_spectrogram:wav_to_spectrogram -- \
--input_wav=/tmp/speech_dataset/happy/ab00c4b2_nohash_0.wav \
--output_png=/tmp/spectrogram.png
```
If you open up `/tmp/spectrogram.png` you should see something like this:
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/spectrogram.png"/>
</div>
Because of TensorFlow's memory order, time in this image is increasing from top
to bottom, with frequencies going from left to right, unlike the usual
convention for spectrograms where time is left to right. You should be able to
see a couple of distinct parts, with the first syllable "Ha" distinct from
"ppy".
Because the human ear is more sensitive to some frequencies than others, it's
been traditional in speech recognition to do further processing to this
representation to turn it into a set of [Mel-Frequency Cepstral
Coefficients](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum), or MFCCs
for short. This is also a two-dimensional, one-channel representation so it can
be treated like an image too. If you're targeting general sounds rather than
speech you may find you can skip this step and operate directly on the
spectrograms.
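To make the pipeline concrete, here's a short Python sketch that runs the same
decode, spectrogram, and MFCC ops this change adds (the calls mirror the ones in
`freeze.py` further down; the file path and the window and stride values are
illustrative, not required settings):
```python
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

wav_data = tf.placeholder(tf.string, [], name='wav_data')
# Decode to mono float samples in the range -1.0 to 1.0, padded/cropped to 1s.
decoded = contrib_audio.decode_wav(
    wav_data, desired_channels=1, desired_samples=16000)
# 30ms windows taken every 10ms at a 16kHz sample rate (illustrative values).
spectrogram = contrib_audio.audio_spectrogram(
    decoded.audio, window_size=480, stride=160, magnitude_squared=True)
# Reduce each frequency slice down to 40 MFCC values.
fingerprint = contrib_audio.mfcc(
    spectrogram, decoded.sample_rate, dct_coefficient_count=40)

with tf.Session() as sess:
  with open('/tmp/speech_dataset/happy/ab00c4b2_nohash_0.wav', 'rb') as f:
    features = sess.run(fingerprint, feed_dict={wav_data: f.read()})
print(features.shape)  # (1, time slices, dct_coefficient_count)
```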
The image that's produced by these processing steps is then fed into a
multi-layer convolutional neural network, with a fully-connected layer followed
by a softmax at the end. You can see the definition of this portion in
[tensorflow/examples/speech_commands/models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py).
## Streaming Accuracy
Most audio recognition applications need to run on a continuous stream of audio,
rather than on individual clips. A typical way to use a model in this
environment is to apply it repeatedly at different offsets in time and average
the results over a short window to produce a smoothed prediction. If you think
of the input as an image, it's continuously scrolling along the time axis. The
words we want to recognize can start at any time, so we need to take a series of
snapshots to have a chance of having an alignment that captures most of the
utterance in the time window we feed into the model. If we sample at a high
enough rate, then we have a good chance of capturing the word in multiple
windows, so averaging the results improves the overall confidence of the
prediction.
For an example of how you can use your model on streaming data, you can look at
[test_streaming_accuracy.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/).
This uses the
[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h)
class to run through a long-form input audio, try to spot words, and compare
those predictions against a ground truth list of labels and times. This makes it
a good example of applying a model to a stream of audio signals over time.
You'll need a long audio file to test it against, along with labels showing
where each word was spoken. If you don't want to record one yourself, you can
generate some synthetic test data using the `generate_streaming_test_wav`
utility. By default this will create a ten minute .wav file with words roughly
every three seconds, and a text file containing the ground truth of when each
word was spoken. These words are pulled from the test portion of your current
dataset, mixed in with background noise. To run it, use:
```
bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav
```
This will save a .wav file to `/tmp/speech_commands_train/streaming_test.wav`,
and a text file listing the labels to
`/tmp/speech_commands_train/streaming_test_labels.txt`. You can then run
accuracy testing with:
```
bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_commands_train/streaming_test.wav \
--ground_truth=/tmp/speech_commands_train/streaming_test_labels.txt \
--verbose
```
This will output information about the number of words correctly matched, how
many were given the wrong labels, and how many times the model triggered when
there was no real word spoken. There are various parameters that control how the
signal averaging works, including `--average_window_ms` which sets the length of
time to average results over, `--sample_stride_ms` which is the time between
applications of the model, `--suppression_ms` which stops subsequent word
detections from triggering for a certain time after an initial one is found, and
`--detection_threshold`, which controls how high the average score must be
before it's considered a solid result.
You'll see that the streaming accuracy outputs three numbers, rather than just
the one metric used in training. This is because different applications have
varying requirements, with some being able to tolerate frequent incorrect
results as long as real words are found (high recall), while others are focused
on ensuring the predicted labels are highly likely to be correct even if some
aren't detected (high precision). The numbers from the tool give you an idea of
how your model will perform in an application, and you can try tweaking the
signal averaging parameters to tune it to give the kind of performance you want.
To understand what the right parameters are for your application, you can look
at generating an [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)
to help you understand the tradeoffs.
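To make those numbers concrete, here's a small sketch of how the reported
percentages relate to the raw counts, following the same arithmetic as
`PrintAccuracyStats` in `accuracy_utils.cc` (the counts themselves are invented
for illustration):
```python
# Illustrative counts only; the tool derives the real numbers from your run.
ground_truth_words = 100  # words actually spoken in the stream so far
matched_words = 90        # ground truth words that got any detection at all
correct_words = 85        # detections with the right label near the right time
wrong_words = 5           # detections near a real word but with the wrong label
false_positives = 10      # detections where no word was actually spoken

print('%.1f%% matched, %.1f%% correctly, %.1f%% wrongly, %.1f%% false positives' % (
    100.0 * matched_words / ground_truth_words,
    100.0 * correct_words / ground_truth_words,
    100.0 * wrong_words / ground_truth_words,
    100.0 * false_positives / ground_truth_words))
```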
## RecognizeCommands
The streaming accuracy tool uses a simple decoder contained in a small
C++ class called
[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h).
This class is fed the output of running the TensorFlow model over time; it
averages the signals and returns information about a label when it has enough
evidence to think that a recognized word has been found. The implementation is
fairly small, just keeping track of the last few predictions and averaging them,
so it's easy to port to other platforms and languages as needed. For example,
it's convenient to do something similar at the Java level on Android, or in Python
on the Raspberry Pi. As long as these implementations share the same logic, you
can tune the parameters that control the averaging using the streaming test
tool, and then transfer them over to your application to get similar results.
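If you do want to port this logic, here's a rough Python sketch of the kind of
averaging and suppression the class performs. The names and structure are
illustrative rather than a copy of the C++ API:
```python
import collections
import numpy as np

class SimpleRecognizeCommands(object):
  """Toy re-implementation of the averaging logic, for illustration only."""

  def __init__(self, labels, average_window_ms=500, suppression_ms=1500,
               detection_threshold=0.7):
    self._labels = labels
    self._average_window_ms = average_window_ms
    self._suppression_ms = suppression_ms
    self._detection_threshold = detection_threshold
    self._previous_results = collections.deque()  # (time_ms, scores) pairs
    self._previous_top_time = -suppression_ms

  def process_latest_result(self, scores, current_time_ms):
    """Smooths the raw model scores and decides if a new word was detected."""
    self._previous_results.append((current_time_ms, np.array(scores)))
    # Drop results that have fallen out of the averaging window.
    earliest_time = current_time_ms - self._average_window_ms
    while self._previous_results and self._previous_results[0][0] < earliest_time:
      self._previous_results.popleft()
    average_scores = np.mean([s for _, s in self._previous_results], axis=0)
    top_index = int(np.argmax(average_scores))
    top_score = float(average_scores[top_index])
    time_since_last_top = current_time_ms - self._previous_top_time
    is_new_command = (top_score > self._detection_threshold and
                      time_since_last_top > self._suppression_ms)
    if is_new_command:
      self._previous_top_time = current_time_ms
    return self._labels[top_index], top_score, is_new_command

recognizer = SimpleRecognizeCommands(['_silence_', '_unknown_', 'yes', 'no'])
print(recognizer.process_latest_result([0.05, 0.05, 0.8, 0.1], 1000))
```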
## Advanced Training
The defaults for the training script are designed to produce good end-to-end
results in a comparatively small file, but there are a lot of options you can
change to customize the results for your own requirements.
### Custom Training Data
By default the script will download the [Speech Commands
dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tgz), but
you can also supply your own training data. To train on your own data, you
should make sure that you have at least several hundred recordings of each sound
you would like to recognize, and arrange them into folders by class. For
example, if you were trying to recognize dog barks from cat miaows, you would
create a root folder called `animal_sounds`, and then within that two
sub-folders called `bark` and `miaow`. You would then organize your audio files
into the appropriate folders.
To point the script to your new audio files, you'll need to set `--data_url=` to
disable downloading of the Speech Commands dataset, and
`--data_dir=/your/data/folder/` to find the files you've just created.
The files themselves should be in 16-bit little-endian PCM-encoded WAVE format. The
sample rate defaults to 16,000, but as long as all your audio is consistently
the same rate (the script doesn't support resampling) you can change this with
the `--sample_rate` argument. The clips should also all be roughly the same
duration. The default expected duration is one second, but you can set this with
the `--clip_duration_ms` flag. If you have clips with variable amounts of
silence at the start, you can look at word alignment tools to standardize them
([here's a quick and dirty approach you can use
too](https://petewarden.com/2017/07/17/a-quick-hack-to-align-single-word-audio-recordings/)).
One issue to watch out for is that you may have very similar repetitions of the
same sounds in your dataset, and these can give misleading metrics if they're
spread across your training, validation, and test sets. For example, the Speech
Commands set has people repeating the same word multiple times. Each one of
those repetitions is likely to be pretty close to the others, so if training was
overfitting and memorizing one, it could perform unrealistically well when it
saw a very similar copy in the test set. To avoid this danger, Speech Commands
tries to ensure that all clips featuring the same word spoken by a single person
are put into the same partition. Clips are assigned to training, test, or
validation sets based on a hash of their filename, to ensure that the
assignments remain steady even as new clips are added and avoid any training
samples migrating into the other sets. To make sure that all a given speaker's
words are in the same bucket, [the hashing
function](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/input_data.py)
ignores anything in a filename after '_nohash_' when calculating the
assignments. This means that if you have file names like `pete_nohash_0.wav` and
`pete_nohash_1.wav`, they're guaranteed to be in the same set.
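Here's a simplified Python sketch of that idea; it follows the approach used in
`input_data.py` but isn't a copy of the exact implementation:
```python
import hashlib
import os
import re

def which_set(filename, validation_percentage=10, testing_percentage=10):
  """Chooses 'training', 'validation', or 'testing' from a stable hash."""
  base_name = os.path.basename(filename)
  # Ignore anything after '_nohash_', so repeats by the same speaker always
  # hash to the same value and land in the same partition.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  hash_value = int(hashlib.sha1(hash_name.encode('utf-8')).hexdigest(), 16)
  percentage_hash = hash_value % 100
  if percentage_hash < validation_percentage:
    return 'validation'
  elif percentage_hash < validation_percentage + testing_percentage:
    return 'testing'
  return 'training'

# Both of these end up in the same set, whichever set that turns out to be.
print(which_set('pete_nohash_0.wav'), which_set('pete_nohash_1.wav'))
```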
### Unknown Class
It's likely that your application will hear sounds that aren't in your training
set, and you'll want the model to indicate that it doesn't recognize the noise
in those cases. To help the network learn what sounds to ignore, you need to
provide some clips of audio that are neither of your classes. To do this, you'd
create `quack`, `oink`, and `moo` subfolders and populate them with noises from
other animals your users might encounter. The `--wanted_words` argument to the
script defines which classes you care about; all the others mentioned in
subfolder names will be used to populate an `_unknown_` class during training.
The Speech Commands dataset has twenty words in its unknown classes, including
the digits zero through nine and random names like "Sheila".
By default 10% of the training examples are picked from the unknown classes, but
you can control this with the `--unknown_percentage` flag. Increasing this will
make the model less likely to mistake unknown words for wanted ones, but making
it too large can backfire as the model might decide it's safest to categorize
all words as unknown!
### Background Noise
Real applications have to recognize audio even when there are other irrelevant
sounds happening in the environment. To build a model that's robust to this kind
of interference, we need to train against recorded audio with similar
properties. The files in the Speech Commands dataset were captured on a variety
of devices by users in many different environments, not in a studio, so that
helps add some realism to the training. To add even more, you can mix in random
segments of environmental audio to the training inputs. In the Speech Commands
set there's a special folder called `_background_noise_` which contains
minute-long WAVE files with white noise and recordings of machinery and everyday
household activity.
Small snippets of these files are chosen at random and mixed at a low volume
into clips during training. The loudness is also chosen randomly, and controlled
by the `--background_volume` argument as a proportion where 0 is silence, and 1
is full volume. Not all clips have background added, so the
`--background_frequency` flag controls what proportion have them mixed in.
Your own application might operate in its own environment with different
background noise patterns than these defaults, so you can supply your own audio
clips in the `_background_noise_` folder. These should be the same sample rate
as your main dataset, but much longer in duration so that a good set of random
segments can be selected from them.
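As a rough illustration of what happens to each training clip, here's a small
numpy sketch of that mixing step. The real logic lives in `input_data.py`; the
function and variable names here are just for illustration:
```python
import numpy as np

def add_background(clip, background, background_volume, background_frequency):
  """Mixes a random background segment into a training clip, some of the time."""
  if np.random.uniform(0, 1) > background_frequency:
    return clip  # Leave some clips with no background mixed in.
  offset = np.random.randint(0, len(background) - len(clip))
  segment = background[offset:offset + len(clip)]
  volume = np.random.uniform(0, background_volume)
  # Keep the mixed audio inside the valid -1.0 to 1.0 sample range.
  return np.clip(clip + segment * volume, -1.0, 1.0)

one_second_clip = np.zeros(16000, dtype=np.float32)      # a silent 1s clip
minute_of_noise = np.random.uniform(-1, 1, 16000 * 60)   # stand-in for a WAV
mixed = add_background(one_second_clip, minute_of_noise,
                       background_volume=0.1, background_frequency=0.8)
```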
### Silence
In most cases the sounds you care about will be intermittent and so it's
important to know when there's no matching audio. To support this, there's a
special `_silence_` label that indicates when the model detects nothing
interesting. Because there's never complete silence in real environments, we
actually have to supply examples with quiet and irrelevant audio. For this, we
reuse the `_background_noise_` folder that's also mixed in to real clips,
pulling short sections of the audio data and feeding those in with the ground
truth class of `_silence_`. By default 10% of the training data is supplied like
this, but the `--silence_percentage` flag can be used to control the proportion. As
with unknown words, setting this higher can weight the model results in favor of
true positives for silence, at the expense of false negatives for words, but too
large a proportion can cause it to fall into the trap of always guessing
silence.
### Time Shifting
Adding in background noise is one way of distorting the training data in a
realistic way to effectively increase the size of the dataset and so improve
overall accuracy; time shifting is another. This involves a random offset in
time of the training sample data, so that a small part of the start or end is
cut off and the opposite section is padded with zeroes. This mimics the natural
variations in starting time in the training data, and is controlled with the
`--time_shift_ms` flag, which defaults to 100ms. Increasing this value will
provide more variation, but at the risk of cutting off important parts of the
audio. A related way of augmenting the data with realistic distortions is by
using [time stretching and pitch scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling),
but that's outside the scope of this tutorial.
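Here's a small numpy sketch of that shifting operation, assuming 16,000 samples
per second; it mirrors the idea rather than the exact code in `input_data.py`:
```python
import numpy as np

def time_shift(samples, sample_rate=16000, time_shift_ms=100):
  """Randomly shifts a clip in time, padding the gap with zeros."""
  shift_samples = int((time_shift_ms * sample_rate) / 1000)
  shift = np.random.randint(-shift_samples, shift_samples + 1)
  if shift > 0:
    # Pad at the start and drop the end, so the audio starts later.
    return np.concatenate([np.zeros(shift, dtype=samples.dtype),
                           samples[:-shift]])
  elif shift < 0:
    # Drop the start and pad at the end, so the audio starts earlier.
    return np.concatenate([samples[-shift:],
                           np.zeros(-shift, dtype=samples.dtype)])
  return samples

shifted_clip = time_shift(np.zeros(16000, dtype=np.float32))
```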
## Customizing the Model
The default model used for this script is pretty large, taking over 800 million
FLOPs for each inference and using 940,000 weight parameters. This runs at
usable speeds on desktop machines or modern phones, but it involves too many
calculations to run at interactive speeds on devices with more limited
resources. To support these use cases, there's an alternative model available,
based on the 'cnn-one-fstride4' architecture described in the [Convolutional
Neural Networks for Small-footprint Keyword Spotting
paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
The number of weight parameters is about the same, but it only needs 11 million
FLOPs to run one prediction, making it much faster.
To use this model, you can specify `--model_architecture=low_latency_conv` on
the command line. You'll also need to update the training rates and the number
of steps, so the full command will look like:
```
python tensorflow/examples/speech_commands/train.py \
--model_architecture=low_latency_conv \
--how_many_training_steps=20000,6000 \
--learning_rate=0.01,0.001
```
This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
If you want to experiment with customizing models, a good place to start is by
tweaking the spectrogram creation parameters. This has the effect of altering
the size of the input image to the model, and the creation code in
[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
will adjust the number of computations and weights automatically to fit with
different dimensions. If you make the input smaller, the model will need fewer
computations to process it, so it can be a great way to trade off some accuracy
for improved latency. The `--window_stride_ms` flag controls how far apart each
frequency analysis sample is from the previous one. If you increase this value, then
fewer samples will be taken for a given duration, and the time axis of the input
will shrink. The `--dct_coefficient_count` flag controls how many buckets are
used for the frequency counting, so reducing this will shrink the input in the
other dimension. The `--window_size_ms` argument doesn't affect the size, but
does control how wide the area used to calculate the frequencies is for each
sample. Reducing the duration of the training samples, controlled by
`--clip_duration_ms`, can also help if the sounds you're looking for are short,
since that also reduces the time dimension of the input. You'll need to make
sure that all your training data contains the right audio in the initial portion
of the clip though.
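To get a feel for how these flags change the input size, here's a short sketch
of the arithmetic, following the way `prepare_model_settings` in `models.py`
derives the spectrogram dimensions (the parameter values are examples, not
necessarily the defaults):
```python
def fingerprint_shape(sample_rate=16000, clip_duration_ms=1000,
                      window_size_ms=30.0, window_stride_ms=10.0,
                      dct_coefficient_count=40):
  """Returns (time slices, frequency buckets) for the model's input 'image'."""
  desired_samples = int(sample_rate * clip_duration_ms / 1000)
  window_size_samples = int(sample_rate * window_size_ms / 1000)
  window_stride_samples = int(sample_rate * window_stride_ms / 1000)
  length_minus_window = desired_samples - window_size_samples
  spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
  return spectrogram_length, dct_coefficient_count

# With these example values the input is 98 x 40; doubling --window_stride_ms
# to 20ms shrinks the time axis to 49, and halving --dct_coefficient_count to
# 20 shrinks the other dimension.
print(fingerprint_shape())
```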
If you have an entirely different model in mind for your problem, you may find
that you can plug it into
[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
and have the rest of the script handle all of the preprocessing and training
mechanics. You would add a new clause to `create_model`, looking for the name of
your architecture and then calling a model creation function. This function is
given the size of the spectrogram input, along with other model information, and
is expected to create TensorFlow ops to read that in and produce an output
prediction vector, and a placeholder to control the dropout rate. The rest of
the script will handle integrating this model into a larger graph doing the
input calculations and applying softmax and a loss function to train it.
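As a sketch of what that looks like, here's a minimal (and deliberately tiny)
model function. It assumes the `model_settings` dictionary exposes
`fingerprint_size` and `label_count` entries, as the existing model functions in
`models.py` do, and the architecture name used in the comment is made up:
```python
import tensorflow as tf

def create_tiny_fc_model(fingerprint_input, model_settings, is_training):
  """Illustrative one-layer model: fingerprint in, logits out."""
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  weights = tf.Variable(
      tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
  bias = tf.Variable(tf.zeros([label_count]))
  logits = tf.matmul(fingerprint_input, weights) + bias
  if is_training:
    return logits, dropout_prob
  return logits

# Inside create_model in models.py you would then dispatch on the name, along
# the lines of:
#   elif model_architecture == 'tiny_fc':
#     return create_tiny_fc_model(fingerprint_input, model_settings, is_training)
```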
One common problem when you're adjusting models and training hyper-parameters is
that not-a-number values can creep in, thanks to numerical precision issues. In
general you can solve these by reducing the magnitude of things like learning
rates and weight initialization functions, but if they're persistent you can
enable the `--check_nans` flag to track down the source of the errors. This will
insert check ops between most regular operations in TensorFlow, and abort the
training process with a useful error message when they're encountered.


@@ -0,0 +1,258 @@
package(
default_visibility = [
"//visibility:public",
],
)
licenses(["notice"]) # Apache 2.0
exports_files([
"LICENSE",
])
load("//tensorflow:tensorflow.bzl", "tf_py_test")
py_library(
name = "models",
srcs = [
"models.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
tf_py_test(
name = "models_test",
size = "small",
srcs = ["models_test.py"],
additional_deps = [
":models",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "input_data",
srcs = [
"input_data.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
tf_py_test(
name = "input_data_test",
size = "small",
srcs = ["input_data_test.py"],
additional_deps = [
":input_data",
"//tensorflow/python:client_testlib",
],
)
py_binary(
name = "train",
srcs = [
"train.py",
],
srcs_version = "PY2AND3",
deps = [
":input_data",
":models",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_binary(
name = "freeze",
srcs = [
"freeze.py",
],
srcs_version = "PY2AND3",
deps = [
":input_data",
":models",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
tf_py_test(
name = "freeze_test",
size = "small",
srcs = ["freeze_test.py"],
additional_deps = [
":freeze",
"//tensorflow/python:client_testlib",
],
)
py_binary(
name = "generate_streaming_test_wav",
srcs = [
"generate_streaming_test_wav.py",
],
srcs_version = "PY2AND3",
deps = [
":input_data",
":models",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
tf_py_test(
name = "generate_streaming_test_wav_test",
size = "small",
srcs = ["generate_streaming_test_wav_test.py"],
additional_deps = [
":generate_streaming_test_wav",
"//tensorflow/python:client_testlib",
],
)
cc_binary(
name = "label_wav_cc",
srcs = [
"label_wav.cc",
],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
py_binary(
name = "label_wav",
srcs = [
"label_wav.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
tf_py_test(
name = "label_wav_test",
size = "medium",
srcs = ["label_wav_test.py"],
additional_deps = [
":label_wav",
"//tensorflow/python:client_testlib",
],
)
cc_library(
name = "recognize_commands",
srcs = [
"recognize_commands.cc",
],
hdrs = [
"recognize_commands.h",
],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
cc_test(
name = "recognize_commands_test",
size = "medium",
srcs = [
"recognize_commands_test.cc",
],
deps = [
":recognize_commands",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "accuracy_utils",
srcs = [
"accuracy_utils.cc",
],
hdrs = [
"accuracy_utils.h",
],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
cc_test(
name = "accuracy_utils_test",
size = "medium",
srcs = [
"accuracy_utils_test.cc",
],
deps = [
":accuracy_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_binary(
name = "test_streaming_accuracy",
srcs = [
"test_streaming_accuracy.cc",
],
deps = [
":accuracy_utils",
":recognize_commands",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@@ -0,0 +1,4 @@
# Speech Commands Example
This is a basic speech recognition example. For more information, see the
tutorial at http://tensorflow.org/tutorials/audio_recognition.


@@ -0,0 +1,138 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
#include <fstream>
#include <iomanip>
#include <unordered_set>
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
Status ReadGroundTruthFile(const string& file_name,
std::vector<std::pair<string, int64>>* result) {
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Ground truth file '", file_name,
"' not found.");
}
result->clear();
string line;
while (std::getline(file, line)) {
std::vector<string> pieces = tensorflow::str_util::Split(line, ',');
if (pieces.size() != 2) {
continue;
}
float timestamp;
if (!tensorflow::strings::safe_strtof(pieces[1].c_str(), &timestamp)) {
return tensorflow::errors::InvalidArgument(
"Wrong number format at line: ", line);
}
string label = pieces[0];
auto timestamp_int64 = static_cast<int64>(timestamp);
result->push_back({label, timestamp_int64});
}
std::sort(result->begin(), result->end(),
[](const std::pair<string, int64>& left,
const std::pair<string, int64>& right) {
return left.second < right.second;
});
return Status::OK();
}
void CalculateAccuracyStats(
const std::vector<std::pair<string, int64>>& ground_truth_list,
const std::vector<std::pair<string, int64>>& found_words,
int64 up_to_time_ms, int64 time_tolerance_ms,
StreamingAccuracyStats* stats) {
int64 latest_possible_time;
if (up_to_time_ms == -1) {
latest_possible_time = std::numeric_limits<int64>::max();
} else {
latest_possible_time = up_to_time_ms + time_tolerance_ms;
}
stats->how_many_ground_truth_words = 0;
for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
const int64 ground_truth_time = ground_truth.second;
if (ground_truth_time > latest_possible_time) {
break;
}
++stats->how_many_ground_truth_words;
}
stats->how_many_false_positives = 0;
stats->how_many_correct_words = 0;
stats->how_many_wrong_words = 0;
std::unordered_set<int64> has_ground_truth_been_matched;
for (const std::pair<string, int64>& found_word : found_words) {
const string& found_label = found_word.first;
const int64 found_time = found_word.second;
const int64 earliest_time = found_time - time_tolerance_ms;
const int64 latest_time = found_time + time_tolerance_ms;
bool has_match_been_found = false;
for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
const int64 ground_truth_time = ground_truth.second;
if ((ground_truth_time > latest_time) ||
(ground_truth_time > latest_possible_time)) {
break;
}
if (ground_truth_time < earliest_time) {
continue;
}
const string& ground_truth_label = ground_truth.first;
if ((ground_truth_label == found_label) &&
(has_ground_truth_been_matched.count(ground_truth_time) == 0)) {
++stats->how_many_correct_words;
} else {
++stats->how_many_wrong_words;
}
has_ground_truth_been_matched.insert(ground_truth_time);
has_match_been_found = true;
break;
}
if (!has_match_been_found) {
++stats->how_many_false_positives;
}
}
stats->how_many_ground_truth_matched = has_ground_truth_been_matched.size();
}
void PrintAccuracyStats(const StreamingAccuracyStats& stats) {
if (stats.how_many_ground_truth_words == 0) {
LOG(INFO) << "No ground truth yet, " << stats.how_many_false_positives
<< " false positives";
} else {
float any_match_percentage =
(stats.how_many_ground_truth_matched * 100.0f) /
stats.how_many_ground_truth_words;
float correct_match_percentage = (stats.how_many_correct_words * 100.0f) /
stats.how_many_ground_truth_words;
float wrong_match_percentage = (stats.how_many_wrong_words * 100.0f) /
stats.how_many_ground_truth_words;
float false_positive_percentage =
(stats.how_many_false_positives * 100.0f) /
stats.how_many_ground_truth_words;
LOG(INFO) << std::setprecision(1) << std::fixed << any_match_percentage
<< "% matched, " << correct_match_percentage << "% correctly, "
<< wrong_match_percentage << "% wrongly, "
<< false_positive_percentage << "% false positives ";
}
}
} // namespace tensorflow


@@ -0,0 +1,60 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
struct StreamingAccuracyStats {
StreamingAccuracyStats()
: how_many_ground_truth_words(0),
how_many_ground_truth_matched(0),
how_many_false_positives(0),
how_many_correct_words(0),
how_many_wrong_words(0) {}
int32 how_many_ground_truth_words;
int32 how_many_ground_truth_matched;
int32 how_many_false_positives;
int32 how_many_correct_words;
int32 how_many_wrong_words;
};
// Takes a file name, and loads a list of expected word labels and times from
// it, as comma-separated variables.
Status ReadGroundTruthFile(const string& file_name,
std::vector<std::pair<string, int64>>* result);
// Given ground truth labels and corresponding predictions found by a model,
// figure out how many were correct. Takes a time limit, so that only
// predictions up to a point in time are considered, in case we're evaluating
// accuracy when the model has only been run on part of the stream.
void CalculateAccuracyStats(
const std::vector<std::pair<string, int64>>& ground_truth_list,
const std::vector<std::pair<string, int64>>& found_words,
int64 up_to_time_ms, int64 time_tolerance_ms,
StreamingAccuracyStats* stats);
// Writes a human-readable description of the statistics to stdout.
void PrintAccuracyStats(const StreamingAccuracyStats& stats);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_


@@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(AccuracyUtilsTest, ReadGroundTruthFile) {
string file_name = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(),
"ground_truth.txt");
string file_data = "a,10\nb,12\n";
TF_ASSERT_OK(WriteStringToFile(Env::Default(), file_name, file_data));
std::vector<std::pair<string, int64>> ground_truth;
TF_ASSERT_OK(ReadGroundTruthFile(file_name, &ground_truth));
ASSERT_EQ(2, ground_truth.size());
EXPECT_EQ("a", ground_truth[0].first);
EXPECT_EQ(10, ground_truth[0].second);
EXPECT_EQ("b", ground_truth[1].first);
EXPECT_EQ(12, ground_truth[1].second);
}
TEST(AccuracyUtilsTest, CalculateAccuracyStats) {
StreamingAccuracyStats stats;
CalculateAccuracyStats({{"a", 1000}, {"b", 9000}},
{{"a", 1200}, {"b", 5000}, {"a", 8700}}, 10000, 500,
&stats);
EXPECT_EQ(2, stats.how_many_ground_truth_words);
EXPECT_EQ(2, stats.how_many_ground_truth_matched);
EXPECT_EQ(1, stats.how_many_false_positives);
EXPECT_EQ(1, stats.how_many_correct_words);
EXPECT_EQ(1, stats.how_many_wrong_words);
}
TEST(AccuracyUtilsTest, PrintAccuracyStats) {
StreamingAccuracyStats stats;
PrintAccuracyStats(stats);
}
} // namespace tensorflow


@@ -0,0 +1,167 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.
Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:
bazel run tensorflow/examples/speech_commands/freeze -- \
--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb
One thing to watch out for is that you need to pass in the same arguments for
`sample_rate` and other command line variables here as you did for the training
script.
The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import sys
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
import tensorflow.examples.speech_commands.input_data as input_data
import tensorflow.examples.speech_commands.models as models
from tensorflow.python.framework import graph_util
FLAGS = None
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
dct_coefficient_count, model_architecture):
"""Creates an audio model with the nodes needed for inference.
Uses the supplied arguments to create a model, and inserts the input and
output nodes that are needed to use the graph for inference.
Args:
wanted_words: Comma-separated list of the words we're trying to recognize.
sample_rate: How many samples per second are in the input audio files.
clip_duration_ms: Length of each audio clip to be analyzed, in milliseconds.
window_size_ms: Time slice duration to estimate frequencies from.
window_stride_ms: How far apart time slices should be.
dct_coefficient_count: Number of frequency bands to analyze.
model_architecture: Name of the kind of model to generate.
"""
words_list = input_data.prepare_words_list(wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
window_stride_ms, dct_coefficient_count)
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
decoded_sample_data = contrib_audio.decode_wav(
wav_data_placeholder,
desired_channels=1,
desired_samples=model_settings['desired_samples'],
name='decoded_sample_data')
spectrogram = contrib_audio.audio_spectrogram(
decoded_sample_data.audio,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
fingerprint_input = contrib_audio.mfcc(
spectrogram,
decoded_sample_data.sample_rate,
dct_coefficient_count=dct_coefficient_count)
logits = models.create_model(
fingerprint_input, model_settings, model_architecture, is_training=False)
# Create an output to use for inference.
tf.nn.softmax(logits, name='labels_softmax')
def main(_):
# Create the model and load its weights.
sess = tf.InteractiveSession()
create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
FLAGS.clip_duration_ms, FLAGS.window_size_ms,
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count,
FLAGS.model_architecture)
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
frozen_graph_def = graph_util.convert_variables_to_constants(
sess, sess.graph_def, ['labels_softmax'])
tf.train.write_graph(
frozen_graph_def,
os.path.dirname(FLAGS.output_file),
os.path.basename(FLAGS.output_file),
as_text=False)
tf.logging.info('Saved frozen graph to %s', FLAGS.output_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--sample_rate',
type=int,
default=16000,
help='Expected sample rate of the wavs',)
parser.add_argument(
'--clip_duration_ms',
type=int,
default=1000,
help='Expected duration in milliseconds of the wavs',)
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',
type=float,
default=10.0,
help='How far to move in time between spectrogram timeslices',)
parser.add_argument(
'--dct_coefficient_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
parser.add_argument(
'--start_checkpoint',
type=str,
default='',
help='If specified, restore this pretrained model before any training.')
parser.add_argument(
'--model_architecture',
type=str,
default='conv',
help='What model architecture to use')
parser.add_argument(
'--wanted_words',
type=str,
default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--output_file', type=str, help='Where to save the frozen graph.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


@@ -0,0 +1,38 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.examples.speech_commands import freeze
from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
def testCreateInferenceGraph(self):
with self.test_session() as sess:
freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 20.0, 10.0, 40,
'conv')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
if __name__ == '__main__':
test.main()


@@ -0,0 +1,281 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a .wav file with synthesized conversational data and labels.
The best way to estimate the real-world performance of an audio recognition
model is by running it against a continuous stream of data, the way that it
would be used in an application. Training evaluations are only run against
discrete individual samples, so the results aren't as realistic.
To make it easy to run evaluations against audio streams, this script uses
samples from the testing partition of the data set, mixes them in at random
positions together with background noise, and saves out the result as one long
audio file.
Here's an example of generating a test file:
bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav -- \
--data_dir=/tmp/my_wavs --background_dir=/tmp/my_backgrounds \
--background_volume=0.1 --test_duration_seconds=600 \
--output_audio_file=/tmp/streaming_test.wav \
--output_labels_file=/tmp/streaming_test_labels.txt
Once you've created a streaming audio file, you can then use the
test_streaming_accuracy tool to calculate accuracy metrics for a model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import math
import sys
import numpy as np
import tensorflow as tf
import tensorflow.examples.speech_commands.input_data as input_data
import tensorflow.examples.speech_commands.models as models
FLAGS = None
def mix_in_audio_sample(track_data, track_offset, sample_data, sample_offset,
clip_duration, sample_volume, ramp_in, ramp_out):
"""Mixes the sample data into the main track at the specified offset.
Args:
track_data: Numpy array holding main audio data. Modified in-place.
track_offset: Where to mix the sample into the main track.
sample_data: Numpy array of audio data to mix into the main track.
sample_offset: Where to start in the audio sample.
clip_duration: How long the sample segment is.
sample_volume: Loudness to mix the sample in at.
ramp_in: Length in samples of volume increase stage.
ramp_out: Length in samples of volume decrease stage.
"""
ramp_out_index = clip_duration - ramp_out
track_end = min(track_offset + clip_duration, track_data.shape[0])
track_end = min(track_end,
track_offset + (sample_data.shape[0] - sample_offset))
sample_range = track_end - track_offset
for i in range(sample_range):
if i < ramp_in:
envelope_scale = i / ramp_in
elif i > ramp_out_index:
envelope_scale = (clip_duration - i) / ramp_out
else:
envelope_scale = 1
sample_input = sample_data[sample_offset + i]
track_data[track_offset
+ i] += sample_input * envelope_scale * sample_volume
def main(_):
  words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      '', FLAGS.data_dir, FLAGS.silence_percentage, 10,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings)
output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)
# Set up background audio.
background_crossover_ms = 500
background_segment_duration_ms = (
FLAGS.clip_duration_ms + background_crossover_ms)
background_segment_duration_samples = int(
(background_segment_duration_ms * FLAGS.sample_rate) / 1000)
background_segment_stride_samples = int(
(FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
background_ramp_samples = int(
((background_crossover_ms / 2) * FLAGS.sample_rate) / 1000)
# Mix the background audio into the main track.
how_many_backgrounds = int(
math.ceil(output_audio_sample_count / background_segment_stride_samples))
for i in range(how_many_backgrounds):
output_offset = int(i * background_segment_stride_samples)
background_index = np.random.randint(len(audio_processor.background_data))
background_samples = audio_processor.background_data[background_index]
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_volume = np.random.uniform(0, FLAGS.background_volume)
mix_in_audio_sample(output_audio, output_offset, background_samples,
background_offset, background_segment_duration_samples,
background_volume, background_ramp_samples,
background_ramp_samples)
# Mix the words into the main track, noting their labels and positions.
output_labels = []
word_stride_ms = FLAGS.clip_duration_ms + FLAGS.word_gap_ms
word_stride_samples = int((word_stride_ms * FLAGS.sample_rate) / 1000)
clip_duration_samples = int(
(FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000)
word_gap_samples = int((FLAGS.word_gap_ms * FLAGS.sample_rate) / 1000)
how_many_words = int(
math.floor(output_audio_sample_count / word_stride_samples))
all_test_data, all_test_labels = audio_processor.get_unprocessed_data(
-1, model_settings, 'testing')
for i in range(how_many_words):
output_offset = (
int(i * word_stride_samples) + np.random.randint(word_gap_samples))
output_offset_ms = (output_offset * 1000) / FLAGS.sample_rate
is_unknown = np.random.randint(100) < FLAGS.unknown_percentage
if is_unknown:
wanted_label = input_data.UNKNOWN_WORD_LABEL
else:
wanted_label = words_list[2 + np.random.randint(len(words_list) - 2)]
test_data_start = np.random.randint(len(all_test_data))
found_sample_data = None
index_lookup = np.arange(len(all_test_data), dtype=np.int32)
np.random.shuffle(index_lookup)
for test_data_offset in range(len(all_test_data)):
test_data_index = index_lookup[(
test_data_start + test_data_offset) % len(all_test_data)]
current_label = all_test_labels[test_data_index]
if current_label == wanted_label:
found_sample_data = all_test_data[test_data_index]
break
mix_in_audio_sample(output_audio, output_offset, found_sample_data, 0,
clip_duration_samples, 1.0, 500, 500)
output_labels.append({'label': wanted_label, 'time': output_offset_ms})
input_data.save_wav_file(FLAGS.output_audio_file, output_audio,
FLAGS.sample_rate)
tf.logging.info('Saved streaming test wav to %s', FLAGS.output_audio_file)
with open(FLAGS.output_labels_file, 'w') as f:
for output_label in output_labels:
f.write('%s, %f\n' % (output_label['label'], output_label['time']))
tf.logging.info('Saved streaming test labels to %s', FLAGS.output_labels_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_url',
type=str,
# pylint: disable=line-too-long
default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
# pylint: enable=line-too-long
help='Location of speech training data')
parser.add_argument(
'--data_dir',
type=str,
default='/tmp/speech_dataset',
help="""\
Where to download the speech training data to.
""")
parser.add_argument(
'--background_dir',
type=str,
default='',
help="""\
      Path to a directory of .wav files to mix in as background noise in the generated audio.
""")
parser.add_argument(
'--background_volume',
type=float,
default=0.1,
help="""\
How loud the background noise should be, between 0 and 1.
""")
parser.add_argument(
'--background_frequency',
type=float,
default=0.8,
help="""\
How many of the training samples have background noise mixed in.
""")
parser.add_argument(
'--silence_percentage',
type=float,
default=10.0,
help="""\
How much of the training data should be silence.
""")
parser.add_argument(
'--testing_percentage',
type=int,
default=10,
help='What percentage of wavs to use as a test set.')
parser.add_argument(
'--validation_percentage',
type=int,
default=10,
help='What percentage of wavs to use as a validation set.')
parser.add_argument(
'--sample_rate',
type=int,
default=16000,
help='Expected sample rate of the wavs.',)
parser.add_argument(
'--clip_duration_ms',
type=int,
default=1000,
help='Expected duration in milliseconds of the wavs.',)
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',
type=float,
default=10.0,
      help='How far to move in time between spectrogram timeslices',)
parser.add_argument(
'--dct_coefficient_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
parser.add_argument(
'--wanted_words',
type=str,
default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--output_audio_file',
type=str,
default='/tmp/speech_commands_train/streaming_test.wav',
help='File to save the generated test audio to.')
parser.add_argument(
'--output_labels_file',
type=str,
default='/tmp/speech_commands_train/streaming_test_labels.txt',
help='File to save the generated test labels to.')
parser.add_argument(
'--test_duration_seconds',
type=int,
default=600,
help='How long the generated test audio file should be.',)
parser.add_argument(
'--word_gap_ms',
type=int,
default=2000,
help='How long the average gap should be between words.',)
parser.add_argument(
'--unknown_percentage',
type=int,
default=30,
help='What percentage of words should be unknown.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
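
Each injected word is recorded in the labels file as a '<label>, <time in ms>' line (see the f.write format in main above). A small sketch, assuming the default output path, of reading those labels back, for example to feed an accuracy evaluation:

from collections import Counter


def load_streaming_labels(labels_path):
  """Parses the 'label, time_ms' lines written by the generator above."""
  labels = []
  with open(labels_path) as f:
    for line in f:
      if not line.strip():
        continue
      label, time_ms = line.rsplit(',', 1)
      labels.append((label.strip(), float(time_ms)))
  return labels


labels = load_streaming_labels(
    '/tmp/speech_commands_train/streaming_test_labels.txt')
print(Counter(label for label, _ in labels))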

View File

@@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test file generation for speech commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.examples.speech_commands import generate_streaming_test_wav
from tensorflow.python.platform import test
class GenerateStreamingTestWavTest(test.TestCase):
def testMixInAudioSample(self):
track_data = np.zeros([10000])
sample_data = np.ones([1000])
generate_streaming_test_wav.mix_in_audio_sample(
track_data, 2000, sample_data, 0, 1000, 1.0, 100, 100)
self.assertNear(1.0, track_data[2500], 0.0001)
self.assertNear(0.0, track_data[3500], 0.0001)
if __name__ == "__main__":
test.main()
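
Beyond the endpoints the test checks, the linear ramps in mix_in_audio_sample can be spot-checked at their midpoints: with 100-sample ramps the envelope should be at half volume 50 samples into the fade-in and 50 samples before the end of the clip. A small sketch under the same setup as the test above:

import numpy as np

from tensorflow.examples.speech_commands import generate_streaming_test_wav

track_data = np.zeros([10000])
sample_data = np.ones([1000])
generate_streaming_test_wav.mix_in_audio_sample(
    track_data, 2000, sample_data, 0, 1000, 1.0, 100, 100)
assert abs(track_data[2050] - 0.5) < 1e-4  # halfway up the fade-in ramp
assert abs(track_data[2950] - 0.5) < 1e-4  # halfway down the fade-out ramp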

View File

@@ -0,0 +1,532 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definitions for simple speech recognition.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import math
import os.path
import random
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
MAX_NUM_WAVS_PER_CLASS = 2**27 - 1 # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185
def prepare_words_list(wanted_words):
"""Prepends common tokens to the custom word list.
Args:
wanted_words: List of strings containing the custom words.
Returns:
List with the standard silence and unknown tokens added.
"""
return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
def which_set(filename, validation_percentage, testing_percentage):
"""Determines which data partition the file should belong to.
We want to keep files in the same training, validation, or testing sets even
if new ones are added over time. This makes it less likely that testing
samples will accidentally be reused in training when long runs are restarted
for example. To keep this stability, a hash of the filename is taken and used
to determine which set it should belong to. This determination only depends on
the name and the set proportions, so it won't change as other files are added.
It's also useful to associate particular files as related (for example words
spoken by the same person), so anything after '_nohash_' in a filename is
ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
'bobby_nohash_1.wav' are always in the same set, for example.
Args:
filename: File path of the data sample.
validation_percentage: How much of the data set to use for validation.
testing_percentage: How much of the data set to use for testing.
Returns:
String, one of 'training', 'validation', or 'testing'.
"""
base_name = os.path.basename(filename)
# We want to ignore anything after '_nohash_' in the file name when
# deciding which set to put a wav in, so the data set creator has a way of
# grouping wavs that are close variations of each other.
hash_name = re.sub(r'_nohash_.*$', '', base_name)
# This looks a bit magical, but we need to decide whether this file should
# go into the training, testing, or validation sets, and we want to keep
# existing files in the same set even if more files are subsequently
# added.
# To do that, we need a stable way of deciding based on just the file name
# itself, so we do a hash of that and then use that to generate a
# probability value that we use to assign it.
hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
percentage_hash = ((int(hash_name_hashed, 16) %
(MAX_NUM_WAVS_PER_CLASS + 1)) *
(100.0 / MAX_NUM_WAVS_PER_CLASS))
if percentage_hash < validation_percentage:
result = 'validation'
elif percentage_hash < (testing_percentage + validation_percentage):
result = 'testing'
else:
result = 'training'
return result
def load_wav_file(filename):
"""Loads an audio file and returns a float PCM-encoded array of samples.
Args:
filename: Path to the .wav file to load.
Returns:
Numpy array holding the sample data as floats between -1.0 and 1.0.
"""
with tf.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.placeholder(tf.string, [])
wav_loader = io_ops.read_file(wav_filename_placeholder)
wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
return sess.run(
wav_decoder,
feed_dict={wav_filename_placeholder: filename}).audio.flatten()
def save_wav_file(filename, wav_data, sample_rate):
"""Saves audio sample data to a .wav audio file.
Args:
filename: Path to save the file to.
wav_data: 2D array of float PCM-encoded audio data.
sample_rate: Samples per second to encode in the file.
"""
with tf.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.placeholder(tf.string, [])
sample_rate_placeholder = tf.placeholder(tf.int32, [])
wav_data_placeholder = tf.placeholder(tf.float32, [None, 1])
wav_encoder = contrib_audio.encode_wav(wav_data_placeholder,
sample_rate_placeholder)
wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
sess.run(
wav_saver,
feed_dict={
wav_filename_placeholder: filename,
sample_rate_placeholder: sample_rate,
wav_data_placeholder: np.reshape(wav_data, (-1, 1))
})
class AudioProcessor(object):
"""Handles loading, partitioning, and preparing audio training data."""
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
model_settings):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
self.prepare_processing_graph(model_settings)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
"""Download and extract data set tar file.
If the data set we're using doesn't already exist, this function
downloads it from the TensorFlow.org website and unpacks it into a
directory.
If the data_url is none, don't download anything and expect the data
directory to contain the correct files already.
Args:
data_url: Web location of the tar file containing the data set.
dest_directory: File path to extract data to.
"""
if not data_url:
return
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = data_url.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write(
'\r>> Downloading %s %.1f%%' %
(filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
try:
filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
except:
tf.logging.error('Failed to download URL: %s to folder: %s', data_url,
filepath)
tf.logging.error('Please make sure you have enough free space and'
' an internet connection')
raise
print()
statinfo = os.stat(filepath)
tf.logging.info('Successfully downloaded %s (%d bytes)', filename,
statinfo.st_size)
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def prepare_data_index(self, silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage):
"""Prepares a list of the samples organized by set and label.
The training loop needs a list of all the available data, organized by
which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it belongs
    to, and uses a stable hash to assign it to a data set partition.
Args:
silence_percentage: How much of the resulting data should be background.
unknown_percentage: How much should be audio outside the wanted classes.
wanted_words: Labels of the classes we want to be able to recognize.
validation_percentage: How much of the data set to use for validation.
testing_percentage: How much of the data set to use for testing.
Returns:
Dictionary containing a list of file information for each set partition,
and a lookup map for each class to determine its numeric index.
Raises:
Exception: If expected files are not found.
"""
# Make sure the shuffling and picking of unknowns is deterministic.
random.seed(RANDOM_SEED)
wanted_words_index = {}
for index, wanted_word in enumerate(wanted_words):
wanted_words_index[wanted_word] = index + 2
self.data_index = {'validation': [], 'testing': [], 'training': []}
unknown_index = {'validation': [], 'testing': [], 'training': []}
all_words = {}
# Look through all the subfolders to find audio samples
search_path = os.path.join(self.data_dir, '*', '*.wav')
for wav_path in gfile.Glob(search_path):
word = re.search('.*/([^/]+)/.*.wav', wav_path).group(1).lower()
# Treat the '_background_noise_' folder as a special case, since we expect
# it to contain long audio samples we mix in to improve training.
if word == BACKGROUND_NOISE_DIR_NAME:
continue
all_words[word] = True
set_index = which_set(wav_path, validation_percentage, testing_percentage)
# If it's a known class, store its detail, otherwise add it to the list
# we'll use to train the unknown label.
if word in wanted_words_index:
self.data_index[set_index].append({'label': word, 'file': wav_path})
else:
unknown_index[set_index].append({'label': word, 'file': wav_path})
if not all_words:
raise Exception('No .wavs found at ' + search_path)
for index, wanted_word in enumerate(wanted_words):
if wanted_word not in all_words:
raise Exception('Expected to find ' + wanted_word +
' in labels but only found ' +
', '.join(all_words.keys()))
# We need an arbitrary file to load as the input for the silence samples.
# It's multiplied by zero later, so the content doesn't matter.
silence_wav_path = self.data_index['training'][0]['file']
for set_index in ['validation', 'testing', 'training']:
set_size = len(self.data_index[set_index])
silence_size = int(math.ceil(set_size * silence_percentage / 100))
for _ in range(silence_size):
self.data_index[set_index].append({
'label': SILENCE_LABEL,
'file': silence_wav_path
})
# Pick some unknowns to add to each partition of the data set.
random.shuffle(unknown_index[set_index])
unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
# Make sure the ordering is random.
for set_index in ['validation', 'testing', 'training']:
random.shuffle(self.data_index[set_index])
# Prepare the rest of the result data structure.
self.words_list = prepare_words_list(wanted_words)
self.word_to_index = {}
for word in all_words:
if word in wanted_words_index:
self.word_to_index[word] = wanted_words_index[word]
else:
self.word_to_index[word] = UNKNOWN_WORD_INDEX
self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
def prepare_background_data(self):
"""Searches a folder for background noise audio, and loads it into memory.
It's expected that the background audio samples will be in a subdirectory
named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
the sample rate of the training data, but can be much longer in duration.
If the '_background_noise_' folder doesn't exist at all, this isn't an
error, it's just taken to mean that no background noise augmentation should
be used. If the folder does exist, but it's empty, that's treated as an
error.
Returns:
List of raw PCM-encoded audio samples of background noise.
Raises:
Exception: If files aren't found in the folder.
"""
self.background_data = []
background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
if not os.path.exists(background_dir):
return self.background_data
with tf.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.placeholder(tf.string, [])
wav_loader = io_ops.read_file(wav_filename_placeholder)
wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
'*.wav')
for wav_path in gfile.Glob(search_path):
wav_data = sess.run(
wav_decoder,
feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
self.background_data.append(wav_data)
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
def prepare_processing_graph(self, model_settings):
"""Builds a TensorFlow graph to apply the input distortions.
Creates a graph that loads a WAVE file, decodes it, scales the volume,
shifts it in time, adds in background noise, calculates a spectrogram, and
then builds an MFCC fingerprint from that.
    This must be called with an active TensorFlow session running, and it
    creates several placeholder inputs and one output:
- wav_filename_placeholder_: Filename of the WAV to load.
- foreground_volume_placeholder_: How loud the main clip should be.
- time_shift_padding_placeholder_: Where to pad the clip.
- time_shift_offset_placeholder_: How much to move the clip in time.
- background_data_placeholder_: PCM sample data for background noise.
- background_volume_placeholder_: Loudness of mixed-in background.
- mfcc_: Output 2D fingerprint of processed audio.
Args:
model_settings: Information about the current model being trained.
"""
desired_samples = model_settings['desired_samples']
self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
wav_decoder = contrib_audio.decode_wav(
wav_loader, desired_channels=1, desired_samples=desired_samples)
# Allow the audio sample's volume to be adjusted.
self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
scaled_foreground = tf.multiply(wav_decoder.audio,
self.foreground_volume_placeholder_)
# Shift the sample's start position, and pad any gaps with zeros.
self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
padded_foreground = tf.pad(
scaled_foreground,
self.time_shift_padding_placeholder_,
mode='CONSTANT')
sliced_foreground = tf.slice(padded_foreground,
self.time_shift_offset_placeholder_,
[desired_samples, -1])
# Mix in background noise.
self.background_data_placeholder_ = tf.placeholder(tf.float32,
[desired_samples, 1])
self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
background_mul = tf.multiply(self.background_data_placeholder_,
self.background_volume_placeholder_)
background_add = tf.add(background_mul, sliced_foreground)
background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
# Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
spectrogram = contrib_audio.audio_spectrogram(
background_clamp,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
self.mfcc_ = contrib_audio.mfcc(
spectrogram,
wav_decoder.sample_rate,
dct_coefficient_count=model_settings['dct_coefficient_count'])
def set_size(self, mode):
"""Calculates the number of samples in the dataset partition.
Args:
mode: Which partition, must be 'training', 'validation', or 'testing'.
Returns:
Number of samples in the partition.
"""
return len(self.data_index[mode])
def get_data(self, how_many, offset, model_settings, background_frequency,
background_volume_range, time_shift, mode, sess):
"""Gather samples from the data set, applying transformations as needed.
When the mode is 'training', a random selection of samples will be returned,
otherwise the first N clips in the partition will be used. This ensures that
validation always uses the same samples, reducing noise in the metrics.
Args:
how_many: Desired number of samples to return. -1 means the entire
contents of this partition.
offset: Where to start when fetching deterministically.
model_settings: Information about the current model being trained.
background_frequency: How many clips will have background noise, 0.0 to
1.0.
background_volume_range: How loud the background noise will be.
time_shift: How much to randomly shift the clips by in time.
mode: Which partition to use, must be 'training', 'validation', or
'testing'.
sess: TensorFlow session that was active when processor was created.
Returns:
List of sample data for the transformed samples, and list of labels in
one-hot form.
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
if how_many == -1:
sample_count = len(candidates)
else:
sample_count = max(0, min(how_many, len(candidates) - offset))
# Data and labels will be populated and returned.
data = np.zeros((sample_count, model_settings['fingerprint_size']))
labels = np.zeros((sample_count, model_settings['label_count']))
desired_samples = model_settings['desired_samples']
use_background = self.background_data and (mode == 'training')
pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use in training.
for i in xrange(offset, offset + sample_count):
# Pick which audio sample to use.
if how_many == -1 or pick_deterministically:
sample_index = i
else:
sample_index = np.random.randint(len(candidates))
sample = candidates[sample_index]
# If we're time shifting, set up the offset for this sample.
if time_shift > 0:
time_shift_amount = np.random.randint(-time_shift, time_shift)
else:
time_shift_amount = 0
if time_shift_amount > 0:
time_shift_padding = [[time_shift_amount, 0], [0, 0]]
time_shift_offset = [0, 0]
else:
time_shift_padding = [[0, -time_shift_amount], [0, 0]]
time_shift_offset = [-time_shift_amount, 0]
input_dict = {
self.wav_filename_placeholder_: sample['file'],
self.time_shift_padding_placeholder_: time_shift_padding,
self.time_shift_offset_placeholder_: time_shift_offset,
}
# Choose a section of background noise to mix in.
if use_background:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
background_offset + desired_samples)]
background_reshaped = background_clipped.reshape([desired_samples, 1])
if np.random.uniform(0, 1) < background_frequency:
background_volume = np.random.uniform(0, background_volume_range)
else:
background_volume = 0
else:
background_reshaped = np.zeros([desired_samples, 1])
background_volume = 0
input_dict[self.background_data_placeholder_] = background_reshaped
input_dict[self.background_volume_placeholder_] = background_volume
# If we want silence, mute out the main sample but leave the background.
if sample['label'] == SILENCE_LABEL:
input_dict[self.foreground_volume_placeholder_] = 0
else:
input_dict[self.foreground_volume_placeholder_] = 1
# Run the graph to produce the output audio.
data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
label_index = self.word_to_index[sample['label']]
labels[i - offset, label_index] = 1
return data, labels
def get_unprocessed_data(self, how_many, model_settings, mode):
"""Retrieve sample data for the given partition, with no transformations.
Args:
how_many: Desired number of samples to return. -1 means the entire
contents of this partition.
model_settings: Information about the current model being trained.
mode: Which partition to use, must be 'training', 'validation', or
'testing'.
Returns:
List of sample data for the samples, and list of labels in one-hot form.
"""
candidates = self.data_index[mode]
if how_many == -1:
sample_count = len(candidates)
else:
sample_count = how_many
desired_samples = model_settings['desired_samples']
words_list = self.words_list
data = np.zeros((sample_count, desired_samples))
labels = []
with tf.Session(graph=tf.Graph()) as sess:
wav_filename_placeholder = tf.placeholder(tf.string, [])
wav_loader = io_ops.read_file(wav_filename_placeholder)
wav_decoder = contrib_audio.decode_wav(
wav_loader, desired_channels=1, desired_samples=desired_samples)
foreground_volume_placeholder = tf.placeholder(tf.float32, [])
scaled_foreground = tf.multiply(wav_decoder.audio,
foreground_volume_placeholder)
for i in range(sample_count):
if how_many == -1:
sample_index = i
else:
sample_index = np.random.randint(len(candidates))
sample = candidates[sample_index]
input_dict = {wav_filename_placeholder: sample['file']}
if sample['label'] == SILENCE_LABEL:
input_dict[foreground_volume_placeholder] = 0
else:
input_dict[foreground_volume_placeholder] = 1
data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
label_index = self.word_to_index[sample['label']]
labels.append(words_list[label_index])
return data, labels
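
Putting the pieces of AudioProcessor together: a hedged end-to-end sketch of pulling one batch of MFCC fingerprints and one-hot labels. The flag values and the download URL mirror the defaults used elsewhere in this change; the data directory is the tutorial default and is only an illustration:

import tensorflow as tf

from tensorflow.examples.speech_commands import input_data
from tensorflow.examples.speech_commands import models

wanted_words = ['yes', 'no', 'up', 'down']
model_settings = models.prepare_model_settings(
    len(input_data.prepare_words_list(wanted_words)), 16000, 1000, 20.0, 10.0,
    40)
audio_processor = input_data.AudioProcessor(
    'http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
    '/tmp/speech_dataset', 10, 10, wanted_words, 10, 10, model_settings)
with tf.Session() as sess:
  fingerprints, one_hot_labels = audio_processor.get_data(
      100, 0, model_settings, 0.8, 0.1, 100, 'training', sess)
  print(fingerprints.shape, one_hot_labels.shape)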

View File

@@ -0,0 +1,212 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data input for speech commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import input_data
from tensorflow.python.platform import test
class InputDataTest(test.TestCase):
def _getWavData(self):
with self.test_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):
with open(filename, "wb") as f:
f.write(wav_data)
def _saveWavFolders(self, root_dir, labels, how_many):
wav_data = self._getWavData()
for label in labels:
dir_name = os.path.join(root_dir, label)
os.mkdir(dir_name)
for i in range(how_many):
file_path = os.path.join(dir_name, "some_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
def _model_settings(self):
return {
"desired_samples": 160,
"fingerprint_size": 40,
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
}
def testPrepareWordsList(self):
words_list = ["a", "b"]
self.assertGreater(
len(input_data.prepare_words_list(words_list)), len(words_list))
def testWhichSet(self):
self.assertEqual(
input_data.which_set("foo.wav", 10, 10),
input_data.which_set("foo.wav", 10, 10))
self.assertEqual(
input_data.which_set("foo_nohash_0.wav", 10, 10),
input_data.which_set("foo_nohash_1.wav", 10, 10))
def testPrepareDataIndex(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
10, 10, self._model_settings())
self.assertLess(0, audio_processor.set_size("training"))
self.assertTrue("training" in audio_processor.data_index)
self.assertTrue("validation" in audio_processor.data_index)
self.assertTrue("testing" in audio_processor.data_index)
self.assertEquals(input_data.UNKNOWN_WORD_INDEX,
audio_processor.word_to_index["c"])
def testPrepareDataIndexEmpty(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
self._model_settings())
self.assertTrue("No .wavs found" in str(e.exception))
def testPrepareDataIndexMissing(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
10, self._model_settings())
self.assertTrue("Expected to find" in str(e.exception))
def testPrepareBackgroundData(self):
tmp_dir = self.get_temp_dir()
background_dir = os.path.join(tmp_dir, "_background_noise_")
os.mkdir(background_dir)
wav_data = self._getWavData()
for i in range(10):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
10, 10, self._model_settings())
self.assertEqual(10, len(audio_processor.background_data))
def testLoadWavFile(self):
tmp_dir = self.get_temp_dir()
file_path = os.path.join(tmp_dir, "load_test.wav")
wav_data = self._getWavData()
self._saveTestWavFile(file_path, wav_data)
sample_data = input_data.load_wav_file(file_path)
self.assertIsNotNone(sample_data)
def testSaveWavFile(self):
tmp_dir = self.get_temp_dir()
file_path = os.path.join(tmp_dir, "load_test.wav")
save_data = np.zeros([16000, 1])
input_data.save_wav_file(file_path, save_data, 16000)
loaded_data = input_data.load_wav_file(file_path)
self.assertIsNotNone(loaded_data)
self.assertEqual(16000, len(loaded_data))
def testPrepareProcessingGraph(self):
tmp_dir = self.get_temp_dir()
wav_dir = os.path.join(tmp_dir, "wavs")
os.mkdir(wav_dir)
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
background_dir = os.path.join(wav_dir, "_background_noise_")
os.mkdir(background_dir)
wav_data = self._getWavData()
for i in range(10):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
model_settings = {
"desired_samples": 160,
"fingerprint_size": 40,
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
self.assertIsNotNone(audio_processor.background_data_placeholder_)
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
self.assertIsNotNone(audio_processor.mfcc_)
def testGetData(self):
tmp_dir = self.get_temp_dir()
wav_dir = os.path.join(tmp_dir, "wavs")
os.mkdir(wav_dir)
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
background_dir = os.path.join(wav_dir, "_background_noise_")
os.mkdir(background_dir)
wav_data = self._getWavData()
for i in range(10):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
model_settings = {
"desired_samples": 160,
"fingerprint_size": 40,
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
with self.test_session() as sess:
result_data, result_labels = audio_processor.get_data(
10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
self.assertEqual(10, len(result_data))
self.assertEqual(10, len(result_labels))
def testGetUnprocessedData(self):
tmp_dir = self.get_temp_dir()
wav_dir = os.path.join(tmp_dir, "wavs")
os.mkdir(wav_dir)
self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
model_settings = {
"desired_samples": 160,
"fingerprint_size": 40,
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
"dct_coefficient_count": 40,
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
10, 10, model_settings)
result_data, result_labels = audio_processor.get_unprocessed_data(
10, model_settings, "training")
self.assertEqual(10, len(result_data))
self.assertEqual(10, len(result_labels))
if __name__ == "__main__":
test.main()
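
The testWhichSet case above relies on the property described in which_set's docstring: the partition is a pure function of the file name (ignoring everything after '_nohash_') and the split percentages. A tiny sketch with made-up file names:

from tensorflow.examples.speech_commands import input_data

# Recordings by the same speaker share the prefix before '_nohash_', so they
# always land in the same partition, no matter how many other files exist.
first = input_data.which_set('yes/bobby_nohash_0.wav', 10, 10)
second = input_data.which_set('yes/bobby_nohash_1.wav', 10, 10)
assert first == second
assert first in ('training', 'validation', 'testing')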

View File

@@ -0,0 +1,176 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <fstream>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::int32;
using tensorflow::string;
namespace {
// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::GraphDef graph_def;
Status load_graph_status =
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Labels file ", file_name,
" not found.");
}
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
return Status::OK();
}
// Analyzes the output of the graph to retrieve the highest scores and
// their positions in the tensor.
void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
Tensor* out_indices, Tensor* out_scores) {
const Tensor& unsorted_scores_tensor = outputs[0];
auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
std::vector<std::pair<int, float>> scores;
scores.reserve(unsorted_scores_flat.size());
for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
}
std::sort(scores.begin(), scores.end(),
[](const std::pair<int, float>& left,
const std::pair<int, float>& right) {
return left.second > right.second;
});
scores.resize(how_many_labels);
Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
for (int i = 0; i < scores.size(); ++i) {
sorted_indices.flat<int>()(i) = scores[i].first;
sorted_scores.flat<float>()(i) = scores[i].second;
}
*out_indices = sorted_indices;
*out_scores = sorted_scores;
}
} // namespace
int main(int argc, char* argv[]) {
string wav = "";
string graph = "";
string labels = "";
string input_name = "wav_data";
string output_name = "labels_softmax";
int32 how_many_labels = 3;
std::vector<Flag> flag_list = {
Flag("wav", &wav, "audio file to be identified"),
Flag("graph", &graph, "model to be executed"),
Flag("labels", &labels, "path to file containing labels"),
Flag("input_name", &input_name, "name of input node in model"),
Flag("output_name", &output_name, "name of output node in model"),
Flag("how_many_labels", &how_many_labels, "number of results to show"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;
Status load_graph_status = LoadGraph(graph, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
return -1;
}
std::vector<string> labels_list;
Status read_labels_status = ReadLabelsFile(labels, &labels_list);
if (!read_labels_status.ok()) {
LOG(ERROR) << read_labels_status;
return -1;
}
string wav_string;
Status read_wav_status = tensorflow::ReadFileToString(
tensorflow::Env::Default(), wav, &wav_string);
if (!read_wav_status.ok()) {
LOG(ERROR) << read_wav_status;
return -1;
}
Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
wav_tensor.scalar<string>()() = wav_string;
// Actually run the audio through the model.
std::vector<Tensor> outputs;
Status run_status =
session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return -1;
}
Tensor indices;
Tensor scores;
GetTopLabels(outputs, how_many_labels, &indices, &scores);
tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
for (int pos = 0; pos < how_many_labels; ++pos) {
const int label_index = indices_flat(pos);
const float score = scores_flat(pos);
LOG(INFO) << labels_list[label_index] << " (" << label_index
<< "): " << score;
}
return 0;
}
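
GetTopLabels above sorts the score/index pairs and keeps the best how_many_labels. For readers following the Python side of this change, a short numpy sketch of the equivalent selection (it mirrors what run_graph in label_wav.py does with argsort):

import numpy as np


def get_top_labels(scores, how_many_labels=3):
  """Returns (indices, scores) of the highest-scoring labels, best first."""
  top_indices = np.argsort(scores)[::-1][:how_many_labels]
  return top_indices, scores[top_indices]


indices, top_scores = get_top_labels(np.array([0.1, 0.7, 0.2]), 2)
# indices -> array([1, 2]), top_scores -> array([0.7, 0.2])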

View File

@@ -0,0 +1,133 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a trained audio graph against a WAVE file and reports the results.
The model, labels and .wav file specified in the arguments will be loaded, and
then the predictions from running the model against the audio data will be
printed to the console. This is a useful script for sanity checking trained
models, and as an example of how to use an audio model from Python.
Here's an example of running it:
python tensorflow/examples/speech_commands/label_wav.py \
--graph=/tmp/my_frozen_graph.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
# pylint: disable=unused-import
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
# pylint: enable=unused-import
FLAGS = None
def load_graph(filename):
"""Unpersists graph from file as default graph."""
with tf.gfile.FastGFile(filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
def load_labels(filename):
"""Read in labels, one label per line."""
return [line.rstrip() for line in tf.gfile.GFile(filename)]
def run_graph(wav_data, labels, input_layer_name, output_layer_name,
num_top_predictions):
"""Runs the audio data through the graph and prints predictions."""
with tf.Session() as sess:
# Feed the audio data as input to the graph.
    # predictions will contain a two-dimensional array, where one
    # dimension represents the input audio count, and the other has
    # predictions per class.
softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})
# Sort to show labels in order of confidence
top_k = predictions.argsort()[-num_top_predictions:][::-1]
for node_id in top_k:
human_string = labels[node_id]
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
return 0
def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
"""Loads the model and labels, and runs the inference to print predictions."""
if not wav or not tf.gfile.Exists(wav):
tf.logging.fatal('Audio file does not exist %s', wav)
if not labels or not tf.gfile.Exists(labels):
tf.logging.fatal('Labels file does not exist %s', labels)
if not graph or not tf.gfile.Exists(graph):
tf.logging.fatal('Graph file does not exist %s', graph)
labels_list = load_labels(labels)
  # load graph, which is imported into the default graph
load_graph(graph)
with open(wav, 'rb') as wav_file:
wav_data = wav_file.read()
run_graph(wav_data, labels_list, input_name, output_name, how_many_labels)
def main(_):
"""Entry point for script, converts flags to arguments."""
label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
FLAGS.output_name, FLAGS.how_many_labels)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--wav', type=str, default='', help='Audio file to be identified.')
parser.add_argument(
'--graph', type=str, default='', help='Model to use for identification.')
parser.add_argument(
'--labels', type=str, default='', help='Path to file containing labels.')
parser.add_argument(
'--input_name',
type=str,
default='wav_data:0',
help='Name of WAVE data input node in model.')
parser.add_argument(
'--output_name',
type=str,
default='labels_softmax:0',
help='Name of node outputting a prediction in the model.')
parser.add_argument(
'--how_many_labels',
type=int,
default=3,
help='Number of results to show.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
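
Besides the flag-driven entry point, label_wav.label_wav can be called directly from Python. A hedged sketch using the default node names and the paths from the docstring above; it assumes a model has already been trained and frozen:

from tensorflow.examples.speech_commands import label_wav

label_wav.label_wav(
    wav='/tmp/speech_dataset/left/a5d485dc_nohash_0.wav',
    labels='/tmp/speech_commands_train/conv_labels.txt',
    graph='/tmp/my_frozen_graph.pb',
    input_name='wav_data:0',
    output_name='labels_softmax:0',
    how_many_labels=3)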

View File

@@ -0,0 +1,64 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for WAVE file labeling tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import label_wav
from tensorflow.python.platform import test
class LabelWavTest(test.TestCase):
def _getWavData(self):
with self.test_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):
with open(filename, "wb") as f:
f.write(wav_data)
def testLabelWav(self):
tmp_dir = self.get_temp_dir()
wav_data = self._getWavData()
wav_filename = os.path.join(tmp_dir, "wav_file.wav")
self._saveTestWavFile(wav_filename, wav_data)
input_name = "test_input"
output_name = "test_output"
graph_filename = os.path.join(tmp_dir, "test_graph.pb")
with tf.Session() as sess:
tf.placeholder(tf.string, name=input_name)
tf.zeros([1, 3], name=output_name)
with open(graph_filename, "wb") as f:
f.write(sess.graph.as_graph_def().SerializeToString())
labels_filename = os.path.join(tmp_dir, "test_labels.txt")
with open(labels_filename, "w") as f:
f.write("a\nb\nc\n")
label_wav.label_wav(wav_filename, labels_filename, graph_filename,
input_name + ":0", output_name + ":0", 3)
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,378 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definitions for simple speech recognition.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
window_size_ms, window_stride_ms,
dct_coefficient_count):
"""Calculates common settings needed for all models.
Args:
label_count: How many classes are to be recognized.
sample_rate: Number of audio samples per second.
clip_duration_ms: Length of each audio clip to be analyzed.
window_size_ms: Duration of frequency analysis window.
window_stride_ms: How far to move in time between frequency windows.
dct_coefficient_count: Number of frequency bins to use for analysis.
Returns:
Dictionary containing common settings.
"""
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
window_stride_samples = int(sample_rate * window_stride_ms / 1000)
length_minus_window = (desired_samples - window_size_samples)
if length_minus_window < 0:
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
fingerprint_size = dct_coefficient_count * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
'dct_coefficient_count': dct_coefficient_count,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
}
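# Worked example, assuming the 16 kHz / 1000 ms clip / 20 ms window / 10 ms
# stride / 40-coefficient defaults used elsewhere in this change, with the ten
# default wanted words (so label_count is 12 once silence and unknown are
# added):
#   desired_samples       = 16000 * 1000 / 1000      = 16000
#   window_size_samples   = 16000 * 20 / 1000        = 320
#   window_stride_samples = 16000 * 10 / 1000        = 160
#   spectrogram_length    = 1 + (16000 - 320) // 160 = 99
#   fingerprint_size      = 40 * 99                  = 3960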
def create_model(fingerprint_input, model_settings, model_architecture,
is_training):
"""Builds a model of the requested architecture compatible with the settings.
There are many possible ways of deriving predictions from a spectrogram
input, so this function provides an abstract interface for creating different
kinds of models in a black-box way. You need to pass in a TensorFlow node as
the 'fingerprint' input, and this should output a batch of 1D features that
describe the audio. Typically this will be derived from a spectrogram that's
been run through an MFCC, but in theory it can be any feature vector of the
size specified in model_settings['fingerprint_size'].
The function will build the graph it needs in the current TensorFlow graph,
and return the tensorflow output that will contain the 'logits' input to the
softmax prediction process. If training flag is on, it will also return a
placeholder node that can be used to control the dropout amount.
See the implementations below for the possible model architectures that can be
requested.
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
model_settings: Dictionary of information about the model.
model_architecture: String specifying which kind of model to create.
is_training: Whether the model is going to be used for training.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
Raises:
Exception: If the architecture type isn't recognized.
"""
if model_architecture == 'single_fc':
return create_single_fc_model(fingerprint_input, model_settings,
is_training)
elif model_architecture == 'conv':
return create_conv_model(fingerprint_input, model_settings, is_training)
elif model_architecture == 'low_latency_conv':
return create_low_latency_conv_model(fingerprint_input, model_settings,
is_training)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
' or "low_latency_conv"')
def load_variables_from_checkpoint(sess, start_checkpoint):
"""Utility function to centralize checkpoint restoration.
Args:
sess: TensorFlow session.
start_checkpoint: Path to saved checkpoint on disk.
"""
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, start_checkpoint)
def create_single_fc_model(fingerprint_input, model_settings, is_training):
"""Builds a model with a single hidden fully-connected layer.
This is a very simple model with just one matmul and bias layer. As you'd
expect, it doesn't produce very accurate results, but it is very fast and
simple, so it's useful for sanity testing.
Here's the layout of the graph:
(fingerprint_input)
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
model_settings: Dictionary of information about the model.
is_training: Whether the model is going to be used for training.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
weights = tf.Variable(
tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
bias = tf.Variable(tf.zeros([label_count]))
logits = tf.matmul(fingerprint_input, weights) + bias
if is_training:
return logits, dropout_prob
else:
return logits
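# Size check, using the settings worked through after prepare_model_settings
# above: with fingerprint_size = 3960 and label_count = 12, this model has
# 3960 * 12 = 47,520 weights plus 12 biases, roughly 47.5K parameters in total,
# which is why it trains quickly but only gives rough accuracy.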
def create_conv_model(fingerprint_input, model_settings, is_training):
"""Builds a standard convolutional model.
This is roughly the network labeled as 'cnn-trad-fpool3' in the
'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf
Here's the layout of the graph:
(fingerprint_input)
v
[Conv2D]<-(weights)
v
[BiasAdd]<-(bias)
v
[Relu]
v
[MaxPool]
v
[Conv2D]<-(weights)
v
[BiasAdd]<-(bias)
v
[Relu]
v
[MaxPool]
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
This produces fairly good quality results, but can involve a large number of
weight parameters and computations. For a cheaper alternative from the same
paper with slightly less accuracy, see 'low_latency_conv' below.
During training, dropout nodes are introduced after each relu, controlled by a
placeholder.
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
model_settings: Dictionary of information about the model.
is_training: Whether the model is going to be used for training.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 20
first_filter_count = 64
first_weights = tf.Variable(
tf.truncated_normal(
[first_filter_height, first_filter_width, 1, first_filter_count],
stddev=0.01))
first_bias = tf.Variable(tf.zeros([first_filter_count]))
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
if is_training:
first_dropout = tf.nn.dropout(first_relu, dropout_prob)
else:
first_dropout = first_relu
max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
second_filter_width = 4
second_filter_height = 10
second_filter_count = 64
second_weights = tf.Variable(
tf.truncated_normal(
[
second_filter_height, second_filter_width, first_filter_count,
second_filter_count
],
stddev=0.01))
second_bias = tf.Variable(tf.zeros([second_filter_count]))
second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
'SAME') + second_bias
second_relu = tf.nn.relu(second_conv)
if is_training:
second_dropout = tf.nn.dropout(second_relu, dropout_prob)
else:
second_dropout = second_relu
second_conv_shape = second_dropout.get_shape()
second_conv_output_width = second_conv_shape[2]
second_conv_output_height = second_conv_shape[1]
second_conv_element_count = int(
second_conv_output_width * second_conv_output_height *
second_filter_count)
flattened_second_conv = tf.reshape(second_dropout,
[-1, second_conv_element_count])
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_conv_element_count, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
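# Rough parameter-count sketch for the 'conv' model above (illustrative only;
# exact totals depend on the time and frequency sizes in model_settings):
#   first conv:   20 * 8 * 1 * 64                             weights
#   second conv:  10 * 4 * 64 * 64                            weights
#   final matmul: ~(time / 2) * (freq / 2) * 64 * label_count weights
# The second convolution and the final fully-connected layer dominate, which is
# the 'large number of weight parameters and computations' mentioned in the
# docstring.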
def create_low_latency_conv_model(fingerprint_input, model_settings,
is_training):
"""Builds a convolutional model with low compute requirements.
This is roughly the network labeled as 'cnn-one-fstride4' in the
'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf
Here's the layout of the graph:
(fingerprint_input)
v
[Conv2D]<-(weights)
v
[BiasAdd]<-(bias)
v
[Relu]
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
[MatMul]<-(weights)
v
[BiasAdd]<-(bias)
v
This produces slightly lower quality results than the 'conv' model, but needs
fewer weight parameters and computations.
During training, dropout nodes are introduced after the relu, controlled by a
placeholder.
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
model_settings: Dictionary of information about the model.
is_training: Whether the model is going to be used for training.
Returns:
TensorFlow node outputting logits results, and optionally a dropout
placeholder.
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
input_frequency_size = model_settings['dct_coefficient_count']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = input_time_size
first_filter_count = 186
first_filter_stride_x = 1
first_filter_stride_y = 4
first_weights = tf.Variable(
tf.truncated_normal(
[first_filter_height, first_filter_width, 1, first_filter_count],
stddev=0.01))
first_bias = tf.Variable(tf.zeros([first_filter_count]))
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
1, first_filter_stride_y, first_filter_stride_x, 1
], 'VALID') + first_bias
first_relu = tf.nn.relu(first_conv)
if is_training:
first_dropout = tf.nn.dropout(first_relu, dropout_prob)
else:
first_dropout = first_relu
first_conv_output_width = math.floor(
(input_frequency_size - first_filter_width + first_filter_stride_x) /
first_filter_stride_x)
first_conv_output_height = math.floor(
(input_time_size - first_filter_height + first_filter_stride_y) /
first_filter_stride_y)
first_conv_element_count = int(
first_conv_output_width * first_conv_output_height * first_filter_count)
flattened_first_conv = tf.reshape(first_dropout,
[-1, first_conv_element_count])
first_fc_output_channels = 128
first_fc_weights = tf.Variable(
tf.truncated_normal(
[first_conv_element_count, first_fc_output_channels], stddev=0.01))
first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 128
second_fc_weights = tf.Variable(
tf.truncated_normal(
[first_fc_output_channels, second_fc_output_channels], stddev=0.01))
second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
final_fc_weights = tf.Variable(
tf.truncated_normal(
[second_fc_output_channels, label_count], stddev=0.01))
final_fc_bias = tf.Variable(tf.zeros([label_count]))
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
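# Rough cost sketch for the 'low_latency_conv' model above (illustrative): the
# single convolution spans the full time axis, so it only slides along the
# frequency axis and produces a small output map, keeping the multiply count
# per inference low; most of the remaining work sits in the small dense layers
# (128 and 128 units, then label_count outputs).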

View File

@ -0,0 +1,86 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.speech_commands import models
from tensorflow.python.platform import test
class ModelsTest(test.TestCase):
def testPrepareModelSettings(self):
self.assertIsNotNone(
models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))
def testCreateModelConvTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
model_settings, "conv", True)
self.assertIsNotNone(logits)
self.assertIsNotNone(dropout_prob)
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelConvInference(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
False)
self.assertIsNotNone(logits)
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
def testCreateModelLowLatencyConvTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "low_latency_conv", True)
self.assertIsNotNone(logits)
self.assertIsNotNone(dropout_prob)
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelFullyConnectedTraining(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "single_fc", True)
self.assertIsNotNone(logits)
self.assertIsNotNone(dropout_prob)
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelBadArchitecture(self):
model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
with self.test_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
models.create_model(fingerprint_input, model_settings,
"bad_architecture", True)
self.assertTrue("not recognized" in str(e.exception))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,127 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/speech_commands/recognize_commands.h"
// Included directly for std::sort and std::numeric_limits used below.
#include <algorithm>
#include <limits>
namespace tensorflow {
RecognizeCommands::RecognizeCommands(const std::vector<string>& labels,
int32 average_window_duration_ms,
float detection_threshold,
int32 suppression_ms, int32 minimum_count)
: labels_(labels),
average_window_duration_ms_(average_window_duration_ms),
detection_threshold_(detection_threshold),
suppression_ms_(suppression_ms),
minimum_count_(minimum_count) {
labels_count_ = labels.size();
previous_top_label_ = "_silence_";
previous_top_label_time_ = std::numeric_limits<int64>::min();
}
Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results,
const int64 current_time_ms,
string* found_command,
float* score,
bool* is_new_command) {
if (latest_results.NumElements() != labels_count_) {
return errors::InvalidArgument(
"The results for recognition should contain ", labels_count_,
" elements, but there are ", latest_results.NumElements());
}
if ((!previous_results_.empty()) &&
(current_time_ms < previous_results_.front().first)) {
return errors::InvalidArgument(
"Results must be fed in increasing time order, but received a "
"timestamp of ",
current_time_ms, " that was earlier than the previous one of ",
previous_results_.front().first);
}
  // Add the latest results to the end of the queue.
previous_results_.push_back({current_time_ms, latest_results});
// Prune any earlier results that are too old for the averaging window.
const int64 time_limit = current_time_ms - average_window_duration_ms_;
while (previous_results_.front().first < time_limit) {
previous_results_.pop_front();
}
// If there are too few results, assume the result will be unreliable and
// bail.
const int64 how_many_results = previous_results_.size();
const int64 earliest_time = previous_results_.front().first;
const int64 samples_duration = current_time_ms - earliest_time;
if ((how_many_results < minimum_count_) ||
(samples_duration < (average_window_duration_ms_ / 4))) {
*found_command = previous_top_label_;
*score = 0.0f;
*is_new_command = false;
return Status::OK();
}
// Calculate the average score across all the results in the window.
std::vector<float> average_scores(labels_count_);
for (const auto& previous_result : previous_results_) {
const Tensor& scores_tensor = previous_result.second;
auto scores_flat = scores_tensor.flat<float>();
for (int i = 0; i < scores_flat.size(); ++i) {
average_scores[i] += scores_flat(i) / how_many_results;
}
}
// Sort the averaged results in descending score order.
std::vector<std::pair<int, float>> sorted_average_scores;
sorted_average_scores.reserve(labels_count_);
for (int i = 0; i < labels_count_; ++i) {
sorted_average_scores.push_back(
std::pair<int, float>({i, average_scores[i]}));
}
std::sort(sorted_average_scores.begin(), sorted_average_scores.end(),
[](const std::pair<int, float>& left,
const std::pair<int, float>& right) {
return left.second > right.second;
});
// See if the latest top score is enough to trigger a detection.
const int current_top_index = sorted_average_scores[0].first;
const string current_top_label = labels_[current_top_index];
const float current_top_score = sorted_average_scores[0].second;
// If we've recently had another label trigger, assume one that occurs too
// soon afterwards is a bad result.
int64 time_since_last_top;
if ((previous_top_label_ == "_silence_") ||
(previous_top_label_time_ == std::numeric_limits<int64>::min())) {
time_since_last_top = std::numeric_limits<int64>::max();
} else {
time_since_last_top = current_time_ms - previous_top_label_time_;
}
if ((current_top_score > detection_threshold_) &&
(current_top_label != previous_top_label_) &&
(time_since_last_top > suppression_ms_)) {
previous_top_label_ = current_top_label;
previous_top_label_time_ = current_time_ms;
*is_new_command = true;
} else {
*is_new_command = false;
}
*found_command = current_top_label;
*score = current_top_score;
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,79 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
#include <deque>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// This class is designed to apply a very primitive decoding model on top of the
// instantaneous results from running an audio recognition model on a single
// window of samples. It applies smoothing over time so that noisy individual
// label scores are averaged, increasing the confidence that apparent matches
// are real.
// To use it, you should create a class object with the configuration you
// want, and then feed results from running a TensorFlow model into the
// processing method. The timestamp for each subsequent call should be
// increasing from the previous, since the class is designed to process a stream
// of data over time.
class RecognizeCommands {
public:
// labels should be a list of the strings associated with each one-hot score.
// The window duration controls the smoothing. Longer durations will give a
// higher confidence that the results are correct, but may miss some commands.
// The detection threshold has a similar effect, with high values increasing
// the precision at the cost of recall. The minimum count controls how many
// results need to be in the averaging window before it's seen as a reliable
// average. This prevents erroneous results when the averaging window is
// initially being populated for example. The suppression argument disables
// further recognitions for a set time after one has been triggered, which can
// help reduce spurious recognitions.
explicit RecognizeCommands(const std::vector<string>& labels,
int32 average_window_duration_ms = 1000,
float detection_threshold = 0.2,
int32 suppression_ms = 500,
int32 minimum_count = 3);
// Call this with the results of running a model on sample data.
Status ProcessLatestResults(const Tensor& latest_results,
const int64 current_time_ms,
string* found_command, float* score,
bool* is_new_command);
private:
// Configuration
std::vector<string> labels_;
int32 average_window_duration_ms_;
float detection_threshold_;
int32 suppression_ms_;
int32 minimum_count_;
// Working variables
std::deque<std::pair<int64, Tensor>> previous_results_;
string previous_top_label_;
int64 labels_count_;
int64 previous_top_label_time_;
};
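// A minimal usage sketch for the class above (hedged: the label list, timing
// values, and `latest_results` tensor are illustrative; `latest_results` would
// be the softmax output from running the model on one window of audio):
//
//   RecognizeCommands recognizer({"_silence_", "yes", "no"},
//                                /*average_window_duration_ms=*/1000,
//                                /*detection_threshold=*/0.7f,
//                                /*suppression_ms=*/500, /*minimum_count=*/3);
//   string found_command;
//   float score;
//   bool is_new_command;
//   Status status = recognizer.ProcessLatestResults(
//       latest_results, current_time_ms, &found_command, &score,
//       &is_new_command);
//   if (status.ok() && is_new_command) {
//     // React to `found_command` here.
//   }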
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_

View File

@ -0,0 +1,114 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/speech_commands/recognize_commands.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(RecognizeCommandsTest, Basic) {
RecognizeCommands recognize_commands({"_silence_", "a", "b"});
Tensor results(DT_FLOAT, {3});
test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});
string found_command;
float score;
bool is_new_command;
TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
results, 0, &found_command, &score, &is_new_command));
}
TEST(RecognizeCommandsTest, FindCommands) {
RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);
Tensor results(DT_FLOAT, {3});
test::FillValues<float>(&results, {0.0f, 1.0f, 0.0f});
bool has_found_new_command = false;
string new_command;
for (int i = 0; i < 10; ++i) {
string found_command;
float score;
bool is_new_command;
int64 current_time_ms = 0 + (i * 100);
TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
results, current_time_ms, &found_command, &score, &is_new_command));
if (is_new_command) {
EXPECT_FALSE(has_found_new_command);
has_found_new_command = true;
new_command = found_command;
}
}
EXPECT_TRUE(has_found_new_command);
EXPECT_EQ("a", new_command);
test::FillValues<float>(&results, {0.0f, 0.0f, 1.0f});
has_found_new_command = false;
new_command = "";
for (int i = 0; i < 10; ++i) {
string found_command;
float score;
bool is_new_command;
int64 current_time_ms = 1000 + (i * 100);
TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
results, current_time_ms, &found_command, &score, &is_new_command));
if (is_new_command) {
EXPECT_FALSE(has_found_new_command);
has_found_new_command = true;
new_command = found_command;
}
}
EXPECT_TRUE(has_found_new_command);
EXPECT_EQ("b", new_command);
}
TEST(RecognizeCommandsTest, BadInputLength) {
RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);
Tensor bad_results(DT_FLOAT, {2});
test::FillValues<float>(&bad_results, {1.0f, 0.0f});
string found_command;
float score;
bool is_new_command;
EXPECT_FALSE(recognize_commands
.ProcessLatestResults(bad_results, 0, &found_command, &score,
&is_new_command)
.ok());
}
TEST(RecognizeCommandsTest, BadInputTimes) {
RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f);
Tensor results(DT_FLOAT, {3});
test::FillValues<float>(&results, {1.0f, 0.0f, 0.0f});
string found_command;
float score;
bool is_new_command;
TF_EXPECT_OK(recognize_commands.ProcessLatestResults(
results, 100, &found_command, &score, &is_new_command));
EXPECT_FALSE(recognize_commands
.ProcessLatestResults(results, 0, &found_command, &score,
&is_new_command)
.ok());
}
} // namespace tensorflow

View File

@ -0,0 +1,310 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*
Tool to create accuracy statistics from running an audio recognition model on a
continuous stream of samples.
This is designed to be an environment for running experiments on new models and
settings to understand the effects they will have in a real application. You
need to supply it with a long audio file containing sounds you want to recognize
and a text file listing the labels of each sound along with the time they occur.
With this information, and a frozen model, the tool will process the audio
stream, apply the model, and keep track of how many mistakes and successes the
model achieved.
The matched percentage is the number of sounds that were correctly classified,
as a percentage of the total number of sounds listed in the ground truth file.
A correct classification is when the right label is chosen within a short time
of the expected ground truth, where the time tolerance is controlled by the
'time_tolerance_ms' command line flag.
The wrong percentage is how many sounds triggered a detection (the classifier
figured out it wasn't silence or background noise), but the detected class was
wrong. This is also a percentage of the total number of ground truth sounds.
The false positive percentage is how many sounds were detected when there was
only silence or background noise. This is also expressed as a percentage of the
total number of ground truth sounds, though since it can be large it may go
above 100%.
The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.
If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in milliseconds
from the start of the file at which it occurs (matching the --ground_truth flag
description below).
Here's an example of how to run the tool:
bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--wav=/tmp/streaming_test_bg.wav \
--graph=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500
*/
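// As a concrete (hypothetical) illustration, a ground truth file for two
// spoken words could contain one <word>,<time in ms> pair per line, e.g.:
//
//   yes,1820
//   no,4750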
// <algorithm> is included directly for the std::copy call below.
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
#include "tensorflow/examples/speech_commands/recognize_commands.h"
// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::int32;
using tensorflow::string;
namespace {
// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::GraphDef graph_def;
Status load_graph_status =
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}
session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Labels file '", file_name,
"' not found.");
}
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
return Status::OK();
}
} // namespace
int main(int argc, char* argv[]) {
string wav = "";
string graph = "";
string labels = "";
string ground_truth = "";
string input_data_name = "decoded_sample_data:0";
string input_rate_name = "decoded_sample_data:1";
string output_name = "labels_softmax";
int32 clip_duration_ms = 1000;
int32 sample_stride_ms = 30;
int32 average_window_ms = 500;
int32 time_tolerance_ms = 750;
int32 suppression_ms = 1500;
float detection_threshold = 0.7f;
bool verbose = false;
std::vector<Flag> flag_list = {
Flag("wav", &wav, "audio file to be identified"),
Flag("graph", &graph, "model to be executed"),
Flag("labels", &labels, "path to file containing labels"),
Flag("ground_truth", &ground_truth,
"path to file containing correct times and labels of words in the "
"audio as <word>,<timestamp in ms> lines"),
Flag("input_data_name", &input_data_name,
"name of input data node in model"),
Flag("input_rate_name", &input_rate_name,
"name of input sample rate node in model"),
Flag("output_name", &output_name, "name of output node in model"),
Flag("clip_duration_ms", &clip_duration_ms,
"length of recognition window"),
Flag("average_window_ms", &average_window_ms,
"length of window to smooth results over"),
Flag("time_tolerance_ms", &time_tolerance_ms,
"maximum gap allowed between a recognition and ground truth"),
Flag("suppression_ms", &suppression_ms,
"how long to ignore others for after a recognition"),
Flag("sample_stride_ms", &sample_stride_ms,
"how often to run recognition"),
Flag("detection_threshold", &detection_threshold,
"what score is required to trigger detection of a word"),
Flag("verbose", &verbose, "whether to log extra debugging information"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;
Status load_graph_status = LoadGraph(graph, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
return -1;
}
std::vector<string> labels_list;
Status read_labels_status = ReadLabelsFile(labels, &labels_list);
if (!read_labels_status.ok()) {
LOG(ERROR) << read_labels_status;
return -1;
}
std::vector<std::pair<string, int64>> ground_truth_list;
Status read_ground_truth_status =
tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
if (!read_ground_truth_status.ok()) {
LOG(ERROR) << read_ground_truth_status;
return -1;
}
string wav_string;
Status read_wav_status = tensorflow::ReadFileToString(
tensorflow::Env::Default(), wav, &wav_string);
if (!read_wav_status.ok()) {
LOG(ERROR) << read_wav_status;
return -1;
}
std::vector<float> audio_data;
uint32 sample_count;
uint16 channel_count;
uint32 sample_rate;
Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
if (!decode_wav_status.ok()) {
LOG(ERROR) << decode_wav_status;
return -1;
}
if (channel_count != 1) {
LOG(ERROR) << "Only mono .wav files can be used, but input has "
<< channel_count << " channels.";
return -1;
}
const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
const int64 sample_stride_samples = (sample_stride_ms * sample_rate) / 1000;
Tensor audio_data_tensor(tensorflow::DT_FLOAT,
tensorflow::TensorShape({clip_duration_samples, 1}));
Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
sample_rate_tensor.scalar<int32>()() = sample_rate;
tensorflow::RecognizeCommands recognize_commands(
labels_list, average_window_ms, detection_threshold, suppression_ms);
std::vector<std::pair<string, int64>> all_found_words;
tensorflow::StreamingAccuracyStats previous_stats;
  // Stop early enough that a full window of clip_duration_samples still fits
  // inside the decoded audio buffer.
  const int64 audio_data_end = (sample_count - clip_duration_samples);
for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
audio_data_offset += sample_stride_samples) {
const float* input_start = &(audio_data[audio_data_offset]);
const float* input_end = input_start + clip_duration_samples;
std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());
// Actually run the audio through the model.
std::vector<Tensor> outputs;
Status run_status = session->Run({{input_data_name, audio_data_tensor},
{input_rate_name, sample_rate_tensor}},
{output_name}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return -1;
}
const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
string found_command;
float score;
bool is_new_command;
Status recognize_status = recognize_commands.ProcessLatestResults(
outputs[0], current_time_ms, &found_command, &score, &is_new_command);
if (!recognize_status.ok()) {
LOG(ERROR) << "Recognition processing failed: " << recognize_status;
return -1;
}
if (is_new_command && (found_command != "_silence_")) {
all_found_words.push_back({found_command, current_time_ms});
if (verbose) {
tensorflow::StreamingAccuracyStats stats;
tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
current_time_ms, time_tolerance_ms,
&stats);
int32 false_positive_delta = stats.how_many_false_positives -
previous_stats.how_many_false_positives;
int32 correct_delta = stats.how_many_correct_words -
previous_stats.how_many_correct_words;
int32 wrong_delta =
stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
string recognition_state;
if (false_positive_delta == 1) {
recognition_state = " (False Positive)";
} else if (correct_delta == 1) {
recognition_state = " (Correct)";
} else if (wrong_delta == 1) {
recognition_state = " (Wrong)";
} else {
LOG(ERROR) << "Unexpected state in statistics";
}
LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
<< recognition_state;
previous_stats = stats;
tensorflow::PrintAccuracyStats(stats);
}
}
}
tensorflow::StreamingAccuracyStats stats;
tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
time_tolerance_ms, &stats);
tensorflow::PrintAccuracyStats(stats);
return 0;
}

View File

@ -0,0 +1,427 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple speech recognition to spot a limited number of keywords.
This is a self-contained example script that will train a very basic audio
recognition model in TensorFlow. It can download the necessary training data,
and runs with reasonable defaults to train within a few hours even when using
only a CPU. For more information, see http://tensorflow.org/tutorials/audio_recognition.
It is intended as an introduction to using neural networks for audio
recognition, and is not a full speech recognition system. For more advanced
speech systems, I recommend looking into Kaldi. This network uses a keyword
detection style to spot discrete words from a small vocabulary, consisting of
"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go".
To run the training process, use:
bazel run tensorflow/examples/speech_commands:train
This will write out checkpoints to /tmp/speech_commands_train/, and will
download over 1GB of open source training data, so you'll need enough free space
and a good internet connection. The default data is a collection of thousands of
one-second .wav files, each containing one spoken word. This data set is
collected from https://aiyprojects.withgoogle.com/open_speech_recording, please
consider contributing to help improve this and other models!
As training progresses, it will print out its accuracy metrics, which should
rise above 90% by the end. Once it's complete, you can run the freeze script to
get a binary GraphDef that you can easily deploy on mobile applications.
If you want to train on your own data, you'll need to create .wavs with your
recordings, all at a consistent length, and then arrange them into subfolders
organized by label. For example, here's a possible file structure:
my_wavs >
up >
audio_0.wav
audio_1.wav
down >
audio_2.wav
audio_3.wav
  other >
audio_4.wav
audio_5.wav
You'll also need to tell the script what labels to look for, using the
`--wanted_words` argument. In this case, 'up,down' might be what you want, and
the audio in the 'other' folder would be used to train an 'unknown' category.
To pull this all together, you'd run:
bazel run tensorflow/examples/speech_commands:train -- \
--data_dir=my_wavs --wanted_words=up,down
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import sys
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import input_data
import models
from tensorflow.python.platform import gfile
FLAGS = None
def main(_):
# We want to see all the logging messages for this tutorial.
tf.logging.set_verbosity(tf.logging.INFO)
# Start a new TensorFlow session.
sess = tf.InteractiveSession()
# Begin by making sure we have the training data we need. If you already have
# training data of your own, use `--data_url= ` on the command line to avoid
# downloading.
model_settings = models.prepare_model_settings(
len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
audio_processor = input_data.AudioProcessor(
FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
FLAGS.testing_percentage, model_settings)
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
# Figure out the learning rates for each training phase. Since it's often
# effective to have high learning rates at the start of training, followed by
# lower levels towards the end, the number of steps and learning rates can be
# specified as comma-separated lists to define the rate at each stage. For
# example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
# will run 13,000 training loops in total, with a rate of 0.001 for the first
# 10,000, and 0.0001 for the final 3,000.
  # Wrap in list() so that len() and indexing below also work under Python 3,
  # where map() returns an iterator rather than a list.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
if len(training_steps_list) != len(learning_rates_list):
raise Exception(
'--how_many_training_steps and --learning_rate must be equal length '
'lists, but are %d and %d long instead' % (len(training_steps_list),
len(learning_rates_list)))
fingerprint_input = tf.placeholder(
tf.float32, [None, fingerprint_size], name='fingerprint_input')
logits, dropout_prob = models.create_model(
fingerprint_input,
model_settings,
FLAGS.model_architecture,
is_training=True)
# Define loss and optimizer
ground_truth_input = tf.placeholder(
tf.float32, [None, label_count], name='groundtruth_input')
# Optionally we can add runtime checks to spot when NaNs or other symptoms of
# numerical errors start occurring during training.
control_dependencies = []
if FLAGS.check_nans:
checks = tf.add_check_numerics_ops()
control_dependencies = [checks]
# Create the back propagation and training evaluation machinery in the graph.
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
labels=ground_truth_input, logits=logits))
tf.summary.scalar('cross_entropy', cross_entropy_mean)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
tf.float32, [], name='learning_rate_input')
train_step = tf.train.GradientDescentOptimizer(
learning_rate_input).minimize(cross_entropy_mean)
predicted_indices = tf.argmax(logits, 1)
expected_indices = tf.argmax(ground_truth_input, 1)
correct_prediction = tf.equal(predicted_indices, expected_indices)
confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
global_step = tf.contrib.framework.get_or_create_global_step()
increment_global_step = tf.assign(global_step, global_step + 1)
saver = tf.train.Saver(tf.global_variables())
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
merged_summaries = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
sess.graph)
validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
tf.global_variables_initializer().run()
start_step = 1
if FLAGS.start_checkpoint:
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
start_step = global_step.eval(session=sess)
tf.logging.info('Training from step: %d ', start_step)
# Save graph.pbtxt.
tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
FLAGS.model_architecture + '.pbtxt')
# Save list of words.
with gfile.GFile(
os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
'w') as f:
f.write('\n'.join(audio_processor.words_list))
# Training loop.
training_steps_max = np.sum(training_steps_list)
for training_step in xrange(start_step, training_steps_max + 1):
# Figure out what the current learning rate is.
training_steps_sum = 0
for i in range(len(training_steps_list)):
training_steps_sum += training_steps_list[i]
if training_step <= training_steps_sum:
learning_rate_value = learning_rates_list[i]
break
# Pull the audio samples we'll use for training.
train_fingerprints, train_ground_truth = audio_processor.get_data(
FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
FLAGS.background_volume, time_shift_samples, 'training', sess)
# Run the graph with this batch of training data.
train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
[
merged_summaries, evaluation_step, cross_entropy_mean, train_step,
increment_global_step
],
feed_dict={
fingerprint_input: train_fingerprints,
ground_truth_input: train_ground_truth,
learning_rate_input: learning_rate_value,
dropout_prob: 0.5
})
train_writer.add_summary(train_summary, training_step)
tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
(training_step, learning_rate_value, train_accuracy * 100,
cross_entropy_value))
is_last_step = (training_step == training_steps_max)
if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
set_size = audio_processor.set_size('validation')
total_accuracy = 0
total_conf_matrix = None
for i in xrange(0, set_size, FLAGS.batch_size):
validation_fingerprints, validation_ground_truth = (
audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
0.0, 0, 'validation', sess))
# Run a validation step and capture training summaries for TensorBoard
# with the `merged` op.
validation_summary, validation_accuracy, conf_matrix = sess.run(
[merged_summaries, evaluation_step, confusion_matrix],
feed_dict={
fingerprint_input: validation_fingerprints,
ground_truth_input: validation_ground_truth,
dropout_prob: 1.0
})
validation_writer.add_summary(validation_summary, training_step)
batch_size = min(FLAGS.batch_size, set_size - i)
total_accuracy += (validation_accuracy * batch_size) / set_size
if total_conf_matrix is None:
total_conf_matrix = conf_matrix
else:
total_conf_matrix += conf_matrix
tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
(training_step, total_accuracy * 100, set_size))
# Save the model checkpoint periodically.
if (training_step % FLAGS.save_step_interval == 0 or
training_step == training_steps_max):
checkpoint_path = os.path.join(FLAGS.train_dir,
FLAGS.model_architecture + '.ckpt')
tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
saver.save(sess, checkpoint_path, global_step=training_step)
set_size = audio_processor.set_size('testing')
tf.logging.info('set_size=%d', set_size)
total_accuracy = 0
total_conf_matrix = None
for i in xrange(0, set_size, FLAGS.batch_size):
test_fingerprints, test_ground_truth = audio_processor.get_data(
FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
test_accuracy, conf_matrix = sess.run(
[evaluation_step, confusion_matrix],
feed_dict={
fingerprint_input: test_fingerprints,
ground_truth_input: test_ground_truth,
dropout_prob: 1.0
})
batch_size = min(FLAGS.batch_size, set_size - i)
total_accuracy += (test_accuracy * batch_size) / set_size
if total_conf_matrix is None:
total_conf_matrix = conf_matrix
else:
total_conf_matrix += conf_matrix
tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100,
set_size))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_url',
type=str,
# pylint: disable=line-too-long
default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
# pylint: enable=line-too-long
help='Location of speech training data archive on the web.')
parser.add_argument(
'--data_dir',
type=str,
default='/tmp/speech_dataset/',
help="""\
Where to download the speech training data to.
""")
parser.add_argument(
'--background_volume',
type=float,
default=0.1,
help="""\
How loud the background noise should be, between 0 and 1.
""")
parser.add_argument(
'--background_frequency',
type=float,
default=0.8,
help="""\
How many of the training samples have background noise mixed in.
""")
parser.add_argument(
'--silence_percentage',
type=float,
default=10.0,
help="""\
How much of the training data should be silence.
""")
parser.add_argument(
'--unknown_percentage',
type=float,
default=10.0,
help="""\
How much of the training data should be unknown words.
""")
parser.add_argument(
'--time_shift_ms',
type=float,
default=100.0,
help="""\
Range to randomly shift the training audio by in time.
""")
parser.add_argument(
'--testing_percentage',
type=int,
default=10,
help='What percentage of wavs to use as a test set.')
parser.add_argument(
'--validation_percentage',
type=int,
default=10,
help='What percentage of wavs to use as a validation set.')
parser.add_argument(
'--sample_rate',
type=int,
default=16000,
help='Expected sample rate of the wavs',)
parser.add_argument(
'--clip_duration_ms',
type=int,
default=1000,
help='Expected duration in milliseconds of the wavs',)
parser.add_argument(
'--window_size_ms',
type=float,
default=20.0,
help='How long each spectrogram timeslice is',)
parser.add_argument(
'--window_stride_ms',
type=float,
default=10.0,
      help='How far to move in time between spectrogram timeslices',)
parser.add_argument(
'--dct_coefficient_count',
type=int,
default=40,
help='How many bins to use for the MFCC fingerprint',)
parser.add_argument(
'--how_many_training_steps',
type=str,
default='15000,3000',
help='How many training loops to run',)
parser.add_argument(
'--eval_step_interval',
type=int,
default=400,
help='How often to evaluate the training results.')
parser.add_argument(
'--learning_rate',
type=str,
default='0.001,0.0001',
help='How large a learning rate to use when training.')
parser.add_argument(
'--batch_size',
type=int,
default=100,
help='How many items to train with at once',)
parser.add_argument(
'--summaries_dir',
type=str,
default='/tmp/retrain_logs',
help='Where to save summary logs for TensorBoard.')
parser.add_argument(
'--wanted_words',
type=str,
default='yes,no,up,down,left,right,on,off,stop,go',
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--train_dir',
type=str,
default='/tmp/speech_commands_train',
help='Directory to write event logs and checkpoint.')
parser.add_argument(
'--save_step_interval',
type=int,
default=100,
help='Save model checkpoint every save_steps.')
parser.add_argument(
'--start_checkpoint',
type=str,
default='',
help='If specified, restore this pretrained model before any training.')
parser.add_argument(
'--model_architecture',
type=str,
default='conv',
help='What model architecture to use')
parser.add_argument(
'--check_nans',
type=bool,
default=False,
help='Whether to check for invalid numbers during processing')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -1133,6 +1133,15 @@ tf_gen_op_wrapper_private_py(
],
)
tf_gen_op_wrapper_private_py(
name = "audio_ops_gen",
require_shape_functions = True,
visibility = [
"//learning/brain/python/ops:__pkg__",
"//tensorflow/contrib/framework:__pkg__",
],
)
tf_gen_op_wrapper_private_py(
name = "candidate_sampling_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],