From 0c6fd1703eb8f990c8b071471b0105339ccf821d Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Fri, 11 Aug 2017 14:30:43 -0700 Subject: [PATCH] Speech keyword detector tutorial Adds a basic training script for a simple audio model to our examples. See third_party/docs_src/tutorials/audio_recognition.md for full documentation PiperOrigin-RevId: 165025732 --- tensorflow/BUILD | 1 + tensorflow/contrib/framework/BUILD | 2 + .../contrib/framework/python/ops/audio_ops.py | 36 ++ tensorflow/core/lib/wav/wav_io.cc | 49 +- tensorflow/core/ops/audio_ops.cc | 4 +- .../docs_src/tutorials/audio_recognition.md | 551 ++++++++++++++++++ tensorflow/examples/speech_commands/BUILD | 258 ++++++++ tensorflow/examples/speech_commands/README.md | 4 + .../speech_commands/accuracy_utils.cc | 138 +++++ .../examples/speech_commands/accuracy_utils.h | 60 ++ .../speech_commands/accuracy_utils_test.cc | 59 ++ tensorflow/examples/speech_commands/freeze.py | 167 ++++++ .../examples/speech_commands/freeze_test.py | 38 ++ .../generate_streaming_test_wav.py | 281 +++++++++ .../generate_streaming_test_wav_test.py | 39 ++ .../examples/speech_commands/input_data.py | 532 +++++++++++++++++ .../speech_commands/input_data_test.py | 212 +++++++ .../examples/speech_commands/label_wav.cc | 176 ++++++ .../examples/speech_commands/label_wav.py | 133 +++++ .../speech_commands/label_wav_test.py | 64 ++ tensorflow/examples/speech_commands/models.py | 378 ++++++++++++ .../examples/speech_commands/models_test.py | 86 +++ .../speech_commands/recognize_commands.cc | 127 ++++ .../speech_commands/recognize_commands.h | 79 +++ .../recognize_commands_test.cc | 114 ++++ .../test_streaming_accuracy.cc | 310 ++++++++++ tensorflow/examples/speech_commands/train.py | 427 ++++++++++++++ tensorflow/python/BUILD | 9 + 28 files changed, 4321 insertions(+), 13 deletions(-) create mode 100644 tensorflow/contrib/framework/python/ops/audio_ops.py create mode 100644 tensorflow/docs_src/tutorials/audio_recognition.md create mode 100644 tensorflow/examples/speech_commands/BUILD create mode 100644 tensorflow/examples/speech_commands/README.md create mode 100644 tensorflow/examples/speech_commands/accuracy_utils.cc create mode 100644 tensorflow/examples/speech_commands/accuracy_utils.h create mode 100644 tensorflow/examples/speech_commands/accuracy_utils_test.cc create mode 100644 tensorflow/examples/speech_commands/freeze.py create mode 100644 tensorflow/examples/speech_commands/freeze_test.py create mode 100644 tensorflow/examples/speech_commands/generate_streaming_test_wav.py create mode 100644 tensorflow/examples/speech_commands/generate_streaming_test_wav_test.py create mode 100644 tensorflow/examples/speech_commands/input_data.py create mode 100644 tensorflow/examples/speech_commands/input_data_test.py create mode 100644 tensorflow/examples/speech_commands/label_wav.cc create mode 100644 tensorflow/examples/speech_commands/label_wav.py create mode 100644 tensorflow/examples/speech_commands/label_wav_test.py create mode 100644 tensorflow/examples/speech_commands/models.py create mode 100644 tensorflow/examples/speech_commands/models_test.py create mode 100644 tensorflow/examples/speech_commands/recognize_commands.cc create mode 100644 tensorflow/examples/speech_commands/recognize_commands.h create mode 100644 tensorflow/examples/speech_commands/recognize_commands_test.cc create mode 100644 tensorflow/examples/speech_commands/test_streaming_accuracy.cc create mode 100644 tensorflow/examples/speech_commands/train.py diff --git a/tensorflow/BUILD 
b/tensorflow/BUILD index 9e372bf052a..71f6d83da3f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -370,6 +370,7 @@ filegroup( "//tensorflow/examples/label_image:all_files", "//tensorflow/examples/learn:all_files", "//tensorflow/examples/saved_model:all_files", + "//tensorflow/examples/speech_commands:all_files", "//tensorflow/examples/tutorials/estimators:all_files", "//tensorflow/examples/tutorials/mnist:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index a953c04c1a9..84c371ec3b5 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,6 +28,7 @@ tf_custom_op_py_library( "python/framework/tensor_util.py", "python/ops/__init__.py", "python/ops/arg_scope.py", + "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", @@ -50,6 +51,7 @@ tf_custom_op_py_library( ":gen_variable_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:audio_ops_gen", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/framework/python/ops/audio_ops.py b/tensorflow/contrib/framework/python/ops/audio_ops.py new file mode 100644 index 00000000000..0aac269b90f --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/audio_ops.py @@ -0,0 +1,36 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=g-short-docstring-punctuation +"""Audio processing and decoding ops. 
+ +@@decode_wav +@@encode_wav +@@audio_spectrogram +@@mfcc +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_audio_ops import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__, []) diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 79918690dbb..1db4746c89e 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -118,6 +118,17 @@ Status ReadValue(const string& data, T* value, int* offset) { return Status::OK(); } +Status ReadString(const string& data, int expected_length, string* value, + int* offset) { + const int new_offset = *offset + expected_length; + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read string"); + } + *value = string(data.begin() + *offset, data.begin() + new_offset); + *offset = new_offset; + return Status::OK(); +} + } // namespace Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, @@ -254,17 +265,33 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string, // Skip over this unused section. offset += 2; } - TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset)); - uint32 data_size; - TF_RETURN_IF_ERROR(ReadValue(wav_string, &data_size, &offset)); - *sample_count = data_size / bytes_per_sample; - const uint32 data_count = *sample_count * *channel_count; - float_values->resize(data_count); - for (int i = 0; i < data_count; ++i) { - int16 single_channel_value = 0; - TF_RETURN_IF_ERROR( - ReadValue(wav_string, &single_channel_value, &offset)); - (*float_values)[i] = Int16SampleToFloat(single_channel_value); + + bool was_data_found = false; + while (offset < wav_string.size()) { + string chunk_id; + TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset)); + uint32 chunk_size; + TF_RETURN_IF_ERROR(ReadValue(wav_string, &chunk_size, &offset)); + if (chunk_id == kDataChunkId) { + if (was_data_found) { + return errors::InvalidArgument("More than one data chunk found in WAV"); + } + was_data_found = true; + *sample_count = chunk_size / bytes_per_sample; + const uint32 data_count = *sample_count * *channel_count; + float_values->resize(data_count); + for (int i = 0; i < data_count; ++i) { + int16 single_channel_value = 0; + TF_RETURN_IF_ERROR( + ReadValue(wav_string, &single_channel_value, &offset)); + (*float_values)[i] = Int16SampleToFloat(single_channel_value); + } + } else { + offset += chunk_size; + } + } + if (!was_data_found) { + return errors::InvalidArgument("No data chunk found in WAV"); } return Status::OK(); } diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc index 91e81f2579c..5e4dba604ea 100644 --- a/tensorflow/core/ops/audio_ops.cc +++ b/tensorflow/core/ops/audio_ops.cc @@ -62,7 +62,7 @@ Status DecodeWavShapeFn(InferenceContext* c) { Status EncodeWavShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); c->set_output(0, c->Scalar()); return Status::OK(); } @@ -104,7 +104,7 @@ Status MfccShapeFn(InferenceContext* c) { ShapeHandle spectrogram; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram)); ShapeHandle unused; - 
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); int32 dct_coefficient_count; TF_RETURN_IF_ERROR( diff --git a/tensorflow/docs_src/tutorials/audio_recognition.md b/tensorflow/docs_src/tutorials/audio_recognition.md new file mode 100644 index 00000000000..57d3ebb9968 --- /dev/null +++ b/tensorflow/docs_src/tutorials/audio_recognition.md @@ -0,0 +1,551 @@ +# How to Train a Simple Audio Recognition Network + +This tutorial will show you how to build a basic speech recognition network that +recognizes ten different words. It's important to know that real speech and +audio recognition systems are much more complex, but like MNIST for images, it +should give you a basic understanding of the techniques involved. Once you've +completed this tutorial, you'll have a model that tries to classify a one second +audio clip as either silence, an unknown word, "yes", "no", "up", "down", +"left", "right", "on", "off", "stop", or "go". You'll also be able to take this +model and run it in an Android application. + +## Preparation + +You should make sure you have TensorFlow installed, and since the script +downloads over 1GB of training data, you'll need a good internet connection and +enough free space on your machine. The training process itself can take several +hours, so make sure you have a machine available for that long. + +## Training + +To begin the training process, go to the TensorFlow source tree and run: + +```bash +python tensorflow/examples/speech_commands/train.py +``` + +The script will start off by downloading the [Speech Commands +dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tar.gz), +which consists of 65,000 WAVE audio files of people saying thirty different +words. This data was collected by Google and released under a CC BY license, and +you can help improve it by [contributing five minutes of your own +voice](https://aiyprojects.withgoogle.com/open_speech_recording). The archive is +over 1GB, so this part may take a while, but you should see progress logs, and +once it's been downloaded once you won't need to do this step again. + +Once the downloading has completed, you'll see logging information that looks +like this: + +``` +I0730 16:53:44.766740 55030 train.py:176] Training from step: 1 +I0730 16:53:47.289078 55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571 +``` + +This shows that the initialization process is done and the training loop has +begun. You'll see that it outputs information for every training step. Here's a +break down of what it means: + +`Step #1` shows that we're on the first step of the training loop. In this case +there are going to be 18,000 steps in total, so you can look at the step number +to get an idea of how close it is to finishing. + +`rate 0.001000` is the learning rate that's controlling the speed of the +network's weight updates. Early on this is a comparatively high number (0.001), +but for later training cycles it will be reduced 10x, to 0.0001. + +`accuracy 7.0%` is the how many classes were correctly predicted on this +training step. This value will often fluctuate a lot, but should increase on +average as training progresses. The model outputs an array of numbers, one for +each label, and each number is the predicted likelihood of the input being that +class. The predicted label is picked by choosing the entry with the highest +score. 
The scores are always between zero and one, with higher values +representing more confidence in the result. + +`cross entropy 2.611571` is the result of the loss function that we're using to +guide the training process. This is a score that's obtained by comparing the +vector of scores from the current training run to the correct labels, and this +should trend downwards during training. + +After a hundred steps, you should see a line like this: + +`I0730 16:54:41.813438 55030 train.py:252] Saving to +"/tmp/speech_commands_train/conv.ckpt-100"` + +This is saving out the current trained weights to a checkpoint file. If your +training script gets interrupted, you can look for the last saved checkpoint and +then restart the script with +`--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-100` as a command line +argument to start from that point. + +## Confusion Matrix + +After four hundred steps, this information will be logged: + +``` +I0730 16:57:38.073667 55030 train.py:243] Confusion Matrix: + [[258 0 0 0 0 0 0 0 0 0 0 0] + [ 7 6 26 94 7 49 1 15 40 2 0 11] + [ 10 1 107 80 13 22 0 13 10 1 0 4] + [ 1 3 16 163 6 48 0 5 10 1 0 17] + [ 15 1 17 114 55 13 0 9 22 5 0 9] + [ 1 1 6 97 3 87 1 12 46 0 0 10] + [ 8 6 86 84 13 24 1 9 9 1 0 6] + [ 9 3 32 112 9 26 1 36 19 0 0 9] + [ 8 2 12 94 9 52 0 6 72 0 0 2] + [ 16 1 39 74 29 42 0 6 37 9 0 3] + [ 15 6 17 71 50 37 0 6 32 2 1 9] + [ 11 1 6 151 5 42 0 8 16 0 0 20]] +``` + +The first section is a [confusion +matrix](https://www.tensorflow.org/api_docs/python/tf/confusion_matrix). To +understand what it means, you first need to know the labels being used, which in +this case are "_silence_", "_unknown_", "yes", "no", "up", "down", "left", +"right", "on", "off", "stop", and "go". Each column represents a set of samples +that were predicted to be each label, so the first column represents all the +clips that were predicted to be silence, the second all those that were +predicted to be unknown words, the third "yes", and so on. + +Each row represents clips by their correct, ground truth labels. The first row +is all the clips that were silence, the second clips that were unknown words, +the third "yes", etc. + +This matrix can be more useful than just a single accuracy score because it +gives a good summary of what mistakes the network is making. In this example you +can see that all of the entries in the first row are zero, apart from the +initial one. Because the first row is all the clips that are actually silence, +this means that none of them were mistakenly labeled as words, so we have no +false negatives for silence. This shows the network is already getting pretty +good at distinguishing silence from words. + +If we look down the first column though, we see a lot of non-zero values. The +column represents all the clips that were predicted to be silence, so positive +numbers outside of the first cell are errors. This means that some clips of real +spoken words are actually being predicted to be silence, so we do have quite a +few false positives. + +A perfect model would produce a confusion matrix where all of the entries were +zero apart from a diagonal line through the center. Spotting deviations from +that pattern can help you figure out how the model is most easily confused, and +once you've identified the problems you can address them by adding more data or +cleaning up categories. 
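+
+If you want to produce a matrix like this from your own evaluation results, the
+calculation itself is simple. Below is a minimal, illustrative sketch using
+TensorFlow's `tf.confusion_matrix` op (TensorFlow 1.x API); the placeholder
+names and example label indices are assumptions for the sake of illustration,
+not the exact code used in `train.py`:
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Ground truth and predicted label indices for a batch of clips, where
+# 0 = _silence_, 1 = _unknown_, 2 = "yes", 3 = "no", and so on.
+ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth_input')
+predictions = tf.placeholder(tf.int64, [None], name='predicted_indices')
+
+# Rows are the true labels, columns are what the model predicted.
+confusion = tf.confusion_matrix(ground_truth, predictions, num_classes=12)
+
+with tf.Session() as sess:
+  matrix = sess.run(confusion, feed_dict={
+      ground_truth: np.array([0, 2, 3, 2], dtype=np.int64),
+      predictions: np.array([0, 2, 2, 3], dtype=np.int64),
+  })
+  print(matrix)
+```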
+
+## Validation
+
+After the confusion matrix, you should see a line like this:
+
+`I0730 16:57:38.073777 55030 train.py:245] Step 400: Validation accuracy = 26.3%
+(N=3093)`
+
+It's good practice to separate your data set into three categories. The largest
+(in this case roughly 80% of the data) is used for training the network, a
+smaller set (10% here, known as "validation") is reserved for evaluation of the
+accuracy during training, and another set (the last 10%, "testing") is used to
+evaluate the accuracy once after the training is complete.
+
+The reason for this split is that there's always a danger that networks will
+start memorizing their inputs during training. By keeping the validation set
+separate, you can ensure that the model works with data it's never seen before.
+The testing set is an additional safeguard to make sure that you haven't just
+been tweaking your model in a way that happens to work for both the training and
+validation sets, but not a broader range of inputs.
+
+The training script automatically separates the data set into these three
+categories, and the logging line above shows the accuracy of the model when run
+on the validation set. Ideally, this should stick fairly close to the training
+accuracy. If the training accuracy increases but the validation accuracy
+doesn't, that's a sign that overfitting is occurring, and your model is only
+learning things about the training clips, not broader patterns that generalize.
+
+## TensorBoard
+
+A good way to visualize how the training is progressing is to use TensorBoard.
+By default, the script saves out events to `/tmp/retrain_logs`, and you can load
+these by running:
+
+`tensorboard --logdir /tmp/retrain_logs`
+
+Then navigate to [http://localhost:6006](http://localhost:6006) in your browser,
+and you'll see charts and graphs showing your model's progress.
+
+*(Screenshot: TensorBoard displaying charts of the training run's progress.)*
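+
+The charts are generated from scalar summaries that the training script writes
+into the events directory mentioned above. If you want to record extra values
+of your own, the pattern looks roughly like this sketch (TensorFlow 1.x summary
+API; the summary name and the fake accuracy values are placeholders, not the
+exact code in `train.py`):
+
+```python
+import tensorflow as tf
+
+accuracy_input = tf.placeholder(tf.float32, [], name='accuracy_input')
+tf.summary.scalar('accuracy', accuracy_input)
+merged_summaries = tf.summary.merge_all()
+
+with tf.Session() as sess:
+  # Writing into a subdirectory of /tmp/retrain_logs lets TensorBoard show
+  # several runs (for example training and validation) side by side.
+  writer = tf.summary.FileWriter('/tmp/retrain_logs/train', sess.graph)
+  for step in range(1, 101):
+    fake_accuracy = step / 100.0  # Replace with a real evaluation result.
+    summary = sess.run(merged_summaries,
+                       feed_dict={accuracy_input: fake_accuracy})
+    writer.add_summary(summary, step)
+  writer.close()
+```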
+ +## Training Finished + +After a few hours of training (depending on your machine's speed), the script +should have completed all 18,000 steps. It will print out a final confusion +matrix, along with an accuracy score, all run on the testing set. With the +default settings, you should see an accuracy of between 85% and 90%. + +Because audio recognition is particularly useful on mobile devices, next we'll +export it to a compact format that's easy to work with on those platforms. To do +that, run this command line: + +``` +python tensorflow/examples/speech_commands/freeze.py \ +--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \ +--output_file=/tmp/my_frozen_graph.pb +``` + +Once the frozen model has been created, you can test it with the `label_wav.py` +script, like this: + +``` +python tensorflow/examples/speech_commands/label_wav.py \ +--graph=/tmp/my_frozen_graph.pb \ +--labels=/tmp/speech_commands_train/conv_labels.txt \ +--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav +``` + +This should print out three labels: + +``` +left (score = 0.81477) +right (score = 0.14139) +_unknown_ (score = 0.03808) +``` + +Hopefully "left" is the top score since that's the correct label, but since the +training is random it may not for the first file you try. Experiment with some +of the other .wav files in that same folder to see how well it does. + +The scores are between zero and one, and higher values mean the model is more +confident in its prediction. + +## How does this Model Work? + +The architecture used in this tutorial is based on some described in the paper +[Convolutional Neural Networks for Small-footprint Keyword +Spotting](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf). +It was chosen because it's comparatively simple, quick to train, and easy to +understand, rather than being state of the art. There are lots of different +approaches to building neural network models to work with audio, including +[recurrent networks](https://svds.com/tensorflow-rnn-tutorial/) or [dilated +(atrous) +convolutions](https://deepmind.com/blog/wavenet-generative-model-raw-audio/). +This tutorial is based on the kind of convolutional network that will feel very +familiar to anyone who's worked with image recognition. That may seem surprising +at first though, since audio is inherently a one-dimensional continuous signal +across time, not a 2D spatial problem. + +We solve that issue by defining a window of time we believe our spoken words +should fit into, and converting the audio signal in that window into an image. +This is done by grouping the incoming audio samples into short segments, just a +few milliseconds long, and calculating the strength of the frequencies across a +set of bands. Each set of frequency strengths from a segment is treated as a +vector of numbers, and those vectors are arranged in time order to form a +two-dimensional array. This array of values can then be treated like a +single-channel image, and is known as a +[spectrogram](https://en.wikipedia.org/wiki/Spectrogram). If you want to view +what kind of image an audio sample produces, you can run the `wav_to_spectrogram +tool: + +``` +bazel run tensorflow/examples/wav_to_spectrogram:wav_to_spectrogram -- \ +--input_wav=/tmp/speech_dataset/happy/ab00c4b2_nohash_0.wav \ +--output_png=/tmp/spectrogram.png +``` + +If you open up `/tmp/spectrogram.png` you should see something like this: + +
+*(Screenshot: spectrogram produced from a recording of the word "happy".)*
+ +Because of TensorFlow's memory order, time in this image is increasing from top +to bottom, with frequencies going from left to right, unlike the usual +convention for spectrograms where time is left to right. You should be able to +see a couple of distinct parts, with the first syllable "Ha" distinct from +"ppy". + +Because the human ear is more sensitive to some frequencies than others, it's +been traditional in speech recognition to do further processing to this +representation to turn it into a set of [Mel-Frequency Cepstral +Coefficients](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum), or MFCCs +for short. This is also a two-dimensional, one-channel representation so it can +be treated like an image too. If you're targeting general sounds rather than +speech you may find you can skip this step and operate directly on the +spectrograms. + +The image that's produced by these processing steps is then fed into a +multi-layer convolutional neural network, with a fully-connected layer followed +by a softmax at the end. You can see the definition of this portion in +[tensorflow/examples/speech_commands/models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py). + +## Streaming Accuracy + +Most audio recognition applications need to run on a continuous stream of audio, +rather than on individual clips. A typical way to use a model in this +environment is to apply it repeatedly at different offsets in time and average +the results over a short window to produce a smoothed prediction. If you think +of the input as an image, it's continuously scrolling along the time axis. The +words we want to recognize can start at any time, so we need to take a series of +snapshots to have a chance of having an alignment that captures most of the +utterance in the time window we feed into the model. If we sample at a high +enough rate, then we have a good chance of capturing the word in multiple +windows, so averaging the results improves the overall confidence of the +prediction. + +For an example of how you can use your model on streaming data, you can look at +[test_streaming_accuracy.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/). +This uses the +[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h) +class to run through a long-form input audio, try to spot words, and compare +those predictions against a ground truth list of labels and times. This makes it +a good example of applying a model to a stream of audio signals over time. + +You'll need a long audio file to test it against, along with labels showing +where each word was spoken. If you don't want to record one yourself, you can +generate some synthetic test data using the `generate_streaming_test_wav` +utility. By default this will create a ten minute .wav file with words roughly +every three seconds, and a text file containing the ground truth of when each +word was spoken. These words are pulled from the test portion of your current +dataset, mixed in with background noise. To run it, use: + +``` +bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav +``` + +This will save a .wav file to `/tmp/speech_commands_train/streaming_test.wav`, +and a text file listing the labels to +`/tmp/speech_commands_train/streaming_test_labels.txt`. 
You can then run +accuracy testing with: + +``` +bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \ +--graph=/tmp/my_frozen_graph.pb \ +--labels=/tmp/speech_commands_train/conv_labels.txt \ +--wav=/tmp/speech_commands_train/streaming_test.wav \ +--ground_truth=/tmp/speech_commands_train/streaming_test_labels.txt \ +--verbose +``` + +This will output information about the number of words correctly matched, how +many were given the wrong labels, and how many times the model triggered when +there was no real word spoken. There are various parameters that control how the +signal averaging works, including `--average_window_ms` which sets the length of +time to average results over, `--sample_stride_ms` which is the time between +applications of the model, `--suppression_ms` which stops subsequent word +detections from triggering for a certain time after an initial one is found, and +`--detection_threshold`, which controls how high the average score must be +before it's considered a solid result. + +You'll see that the streaming accuracy outputs three numbers, rather than just +the one metric used in training. This is because different applications have +varying requirements, with some being able to tolerate frequent incorrect +results as long as real words are found (high recall), while others very focused +on ensuring the predicted labels are highly likely to be correct even if some +aren't detected (high precision). The numbers from the tool give you an idea of +how your model will perform in an application, and you can try tweaking the +signal averaging parameters to tune it to give the kind of performance you want. +To understand what the right parameters are for your application, you can look +at generating an [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) +to help you understand the tradeoffs. + +## RecognizeCommands + +The streaming accuracy tool uses a simple decoder contained in a small +C++ class called +[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h). +This class is fed the output of running the TensorFlow model over time, it +averages the signals, and returns information about a label when it has enough +evidence to think that a recognized word has been found. The implementation is +fairly small, just keeping track of the last few predictions and averaging them, +so it's easy to port to other platforms and languages as needed. For example, +it's convenient to do something similar at the Java level on Android, or Python +on the Raspberry Pi. As long as these implementations share the same logic, you +can tune the parameters that control the averaging using the streaming test +tool, and then transfer them over to your application to get similar results. + +## Advanced Training + +The defaults for the training script are designed to produce good end to end +results in a comparatively small file, but there are a lot of options you can +change to customize the results for your own requirements. + +### Custom Training Data + +By default the script will download the [Speech Commands +dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tgz), but +you can also supply your own training data. To train on your own data, you +should make sure that you have at least several hundred recordings of each sound +you would like to recognize, and arrange them into folders by class. 
For +example, if you were trying to recognize dog barks from cat miaows, you would +create a root folder called `animal_sounds`, and then within that two +sub-folders called `bark` and `miaow`. You would then organize your audio files +into the appropriate folders. + +To point the script to your new audio files, you'll need to set `--data_url=` to +disable downloading of the Speech Commands dataset, and +`--data_dir=/your/data/folder/` to find the files you've just created. + +The files themselves should be 16-bit little-endian PCM-encoded WAVE format. The +sample rate defaults to 16,000, but as long as all your audio is consistently +the same rate (the script doesn't support resampling) you can change this with +the `--sample_rate` argument. The clips should also all be roughly the same +duration. The default expected duration is one second, but you can set this with +the `--clip_duration_ms` flag. If you have clips with variable amounts of +silence at the start, you can look at word alignment tools to standardize them +([here's a quick and dirty approach you can use +too](https://petewarden.com/2017/07/17/a-quick-hack-to-align-single-word-audio-recordings/)). + +One issue to watch out for is that you may have very similar repetitions of the +same sounds in your dataset, and these can give misleading metrics if they're +spread across your training, validation, and test sets. For example, the Speech +Commands set has people repeating the same word multiple times. Each one of +those repetitions is likely to be pretty close to the others, so if training was +overfitting and memorizing one, it could perform unrealistically well when it +saw a very similar copy in the test set. To avoid this danger, Speech Commands +trys to ensure that all clips featuring the same word spoken by a single person +are put into the same partition. Clips are assigned to training, test, or +validation sets based on a hash of their filename, to ensure that the +assignments remain steady even as new clips are added and avoid any training +samples migrating into the other sets. To make sure that all a given speaker's +words are in the same bucket, [the hashing +function](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/input_data.py) +ignores anything in a filename after '_nohash_' when calculating the +assignments. This means that if you have file names like `pete_nohash_0.wav` and +`pete_nohash_1.wav`, they're guaranteed to be in the same set. + +### Unknown Class + +It's likely that your application will hear sounds that aren't in your training +set, and you'll want the model to indicate that it doesn't recognize the noise +in those cases. To help the network learn what sounds to ignore, you need to +provide some clips of audio that are neither of your classes. To do this, you'd +create `quack`, `oink`, and `moo` subfolders and populate them with noises from +other animals your users might encounter. The `--wanted_words` argument to the +script defines which classes you care about, all the others mentioned in +subfolder names will be used to populate an `_unknown_` class during training. +The Speech Commands dataset has twenty words in its unknown classes, including +the digits zero through nine and random names like "Sheila". + +By default 10% of the training examples are picked from the unknown classes, but +you can control this with the `--unknown_percentage` flag. 
Increasing this will +make the model less likely to mistake unknown words for wanted ones, but making +it too large can backfire as the model might decide it's safest to categorize +all words as unknown! + +### Background Noise + +Real applications have to recognize audio even when there are other irrelevant +sounds happening in the environment. To build a model that's robust to this kind +of interference, we need to train against recorded audio with similar +properties. The files in the Speech Commands dataset were captured on a variety +of devices by users in many different environments, not in a studio, so that +helps add some realism to the training. To add even more, you can mix in random +segments of environmental audio to the training inputs. In the Speech Commands +set there's a special folder called `_background_noise_` which contains +minute-long WAVE files with white noise and recordings of machinery and everyday +household activity. + +Small snippets of these files are chosen at random and mixed at a low volume +into clips during training. The loudness is also chosen randomly, and controlled +by the `--background_volume` argument as a proportion where 0 is silence, and 1 +is full volume. Not all clips have background added, so the +`--background_frequency` flag controls what proportion have them mixed in. + +Your own application might operate in its own environment with different +background noise patterns than these defaults, so you can supply your own audio +clips in the `_background_noise_` folder. These should be the same sample rate +as your main dataset, but much longer in duration so that a good set of random +segments can be selected from them. + +### Silence + +In most cases the sounds you care about will be intermittent and so it's +important to know when there's no matching audio. To support this, there's a +special `_silence_` label that indicates when the model detects nothing +interesting. Because there's never complete silence in real environments, we +actually have to supply examples with quiet and irrelevant audio. For this, we +reuse the `_background_noise_` folder that's also mixed in to real clips, +pulling short sections of the audio data and feeding those in with the ground +truth class of `_silence_`. By default 10% of the training data is supplied like +this, but the `--silence_percentage` can be used to control the proportion. As +with unknown words, setting this higher can weight the model results in favor of +true positives for silence, at the expense of false negatives for words, but too +large a proportion can cause it to fall into the trap of always guessing +silence. + +### Time Shifting + +Adding in background noise is one way of distorting the training data in a +realistic way to effectively increase the size of the dataset, and so increase +overall accuracy, and time shifting is another. This involves a random offset in +time of the training sample data, so that a small part of the start or end is +cut off and the opposite section is padded with zeroes. This mimics the natural +variations in starting time in the training data, and is controlled with the +`--time_shift_ms` flag, which defaults to 100ms. Increasing this value will +provide more variation, but at the risk of cutting off important parts of the +audio. A related way of augmenting the data with realistic distortions is by +using [time stretching and pitch scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling), +but that's outside the scope of this tutorial. 
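+
+To make the mechanics concrete, here's an illustrative NumPy sketch of the
+padding and slicing that time shifting involves. It's a simplified stand-in for
+what the input pipeline does rather than the actual code in `input_data.py`:
+
+```python
+import numpy as np
+
+
+def time_shift(samples, shift_amount):
+  """Shifts a 1-D array of audio samples, filling the gap with zeros.
+
+  A positive shift_amount delays the audio (silence is added at the start),
+  a negative one advances it (silence is added at the end). The output is the
+  same length as the input, so part of the clip gets cut off.
+  """
+  if shift_amount > 0:
+    padded = np.concatenate([np.zeros(shift_amount), samples])
+    return padded[:len(samples)]
+  else:
+    padded = np.concatenate([samples, np.zeros(-shift_amount)])
+    return padded[len(padded) - len(samples):]
+
+
+# One second of fake audio at 16,000 samples per second, shifted by a random
+# amount of up to 100ms (1,600 samples) in either direction.
+clip = np.random.uniform(-1.0, 1.0, 16000)
+shift = np.random.randint(-1600, 1601)
+augmented = time_shift(clip, shift)
+```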
+
+## Customizing the Model
+
+The default model used for this script is pretty large, taking over 800 million
+FLOPs for each inference and using 940,000 weight parameters. This runs at
+usable speeds on desktop machines or modern phones, but it involves too many
+calculations to run at interactive speeds on devices with more limited
+resources. To support these use cases, there's an alternative model available,
+based on the 'cnn-one-fstride4' architecture described in the [Convolutional
+Neural Networks for Small-footprint Keyword Spotting
+paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
+The number of weight parameters is about the same, but it only needs 11 million
+FLOPs to run one prediction, making it much faster.
+
+To use this model, you can specify `--model_architecture=low_latency_conv` on
+the command line. You'll also need to update the training rates and the number
+of steps, so the full command will look like:
+
+```
+python tensorflow/examples/speech_commands/train.py \
+--model_architecture=low_latency_conv \
+--how_many_training_steps=20000,6000 \
+--learning_rate=0.01,0.001
+```
+
+This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
+then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
+
+If you want to experiment with customizing models, a good place to start is by
+tweaking the spectrogram creation parameters. This has the effect of altering
+the size of the input image to the model, and the creation code in
+[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
+will adjust the number of computations and weights automatically to fit with
+different dimensions. If you make the input smaller, the model will need fewer
+computations to process it, so it can be a great way to trade off some accuracy
+for improved latency. The `--window_stride_ms` flag controls how far apart each
+frequency analysis sample is from the previous one. If you increase this value,
+then fewer samples will be taken for a given duration, and the time axis of the
+input will shrink. The `--dct_coefficient_count` flag controls how many buckets
+are used for the frequency counting, so reducing this will shrink the input in
+the other dimension. The `--window_size_ms` argument doesn't affect the size,
+but does control how wide the area used to calculate the frequencies is for each
+sample. Reducing the duration of the training samples, controlled by
+`--clip_duration_ms`, can also help if the sounds you're looking for are short,
+since that also reduces the time dimension of the input. You'll need to make
+sure that all your training data contains the right audio in the initial portion
+of the clip though.
+
+If you have an entirely different model in mind for your problem, you may find
+that you can plug it into
+[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
+and have the rest of the script handle all of the preprocessing and training
+mechanics. You would add a new clause to `create_model`, looking for the name of
+your architecture and then calling a model creation function. This function is
+given the size of the spectrogram input, along with other model information, and
+is expected to create TensorFlow ops to read that in and produce an output
+prediction vector, and a placeholder to control the dropout rate. 
The rest of +the script will handle integrating this model into a larger graph doing the +input calculations and applying softmax and a loss function to train it. + +One common problem when you're adjusting models and training hyper-parameters is +that not-a-number values can creep in, thanks to numerical precision issues. In +general you can solve these by reducing the magnitude of things like learning +rates and weight initialization functions, but if they're persistent you can +enable the `--check_nans` flag to track down the source of the errors. This will +insert check ops between most regular operations in TensorFlow, and abort the +training process with a useful error message when they're encountered. diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD new file mode 100644 index 00000000000..4307b9471d4 --- /dev/null +++ b/tensorflow/examples/speech_commands/BUILD @@ -0,0 +1,258 @@ +package( + default_visibility = [ + "//visibility:public", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", +]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "models", + srcs = [ + "models.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "models_test", + size = "small", + srcs = ["models_test.py"], + additional_deps = [ + ":models", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "input_data", + srcs = [ + "input_data.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "input_data_test", + size = "small", + srcs = ["input_data_test.py"], + additional_deps = [ + ":input_data", + "//tensorflow/python:client_testlib", + ], +) + +py_binary( + name = "train", + srcs = [ + "train.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":input_data", + ":models", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_binary( + name = "freeze", + srcs = [ + "freeze.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":input_data", + ":models", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "freeze_test", + size = "small", + srcs = ["freeze_test.py"], + additional_deps = [ + ":freeze", + "//tensorflow/python:client_testlib", + ], +) + +py_binary( + name = "generate_streaming_test_wav", + srcs = [ + "generate_streaming_test_wav.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":input_data", + ":models", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +tf_py_test( + name = "generate_streaming_test_wav_test", + size = "small", + srcs = ["generate_streaming_test_wav_test.py"], + additional_deps = [ + ":generate_streaming_test_wav", + "//tensorflow/python:client_testlib", + ], +) + +cc_binary( + name = "label_wav_cc", + srcs = [ + "label_wav.cc", + ], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +py_binary( + name = "label_wav", + srcs = [ + "label_wav.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "label_wav_test", + size = "medium", + srcs = ["label_wav_test.py"], + 
additional_deps = [ + ":label_wav", + "//tensorflow/python:client_testlib", + ], +) + +cc_library( + name = "recognize_commands", + srcs = [ + "recognize_commands.cc", + ], + hdrs = [ + "recognize_commands.h", + ], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "recognize_commands_test", + size = "medium", + srcs = [ + "recognize_commands_test.cc", + ], + deps = [ + ":recognize_commands", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "accuracy_utils", + srcs = [ + "accuracy_utils.cc", + ], + hdrs = [ + "accuracy_utils.h", + ], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "accuracy_utils_test", + size = "medium", + srcs = [ + "accuracy_utils_test.cc", + ], + deps = [ + ":accuracy_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_binary( + name = "test_streaming_accuracy", + srcs = [ + "test_streaming_accuracy.cc", + ], + deps = [ + ":accuracy_utils", + ":recognize_commands", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/examples/speech_commands/README.md b/tensorflow/examples/speech_commands/README.md new file mode 100644 index 00000000000..3b782101292 --- /dev/null +++ b/tensorflow/examples/speech_commands/README.md @@ -0,0 +1,4 @@ +# Speech Commands Example + +This is a basic speech recognition example. For more information, see the +tutorial at http://tensorflow.org/tutorials/audio_recognition. diff --git a/tensorflow/examples/speech_commands/accuracy_utils.cc b/tensorflow/examples/speech_commands/accuracy_utils.cc new file mode 100644 index 00000000000..b9d3dd66991 --- /dev/null +++ b/tensorflow/examples/speech_commands/accuracy_utils.cc @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/examples/speech_commands/accuracy_utils.h" + +#include +#include +#include + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +Status ReadGroundTruthFile(const string& file_name, + std::vector>* result) { + std::ifstream file(file_name); + if (!file) { + return tensorflow::errors::NotFound("Ground truth file '", file_name, + "' not found."); + } + result->clear(); + string line; + while (std::getline(file, line)) { + std::vector pieces = tensorflow::str_util::Split(line, ','); + if (pieces.size() != 2) { + continue; + } + float timestamp; + if (!tensorflow::strings::safe_strtof(pieces[1].c_str(), ×tamp)) { + return tensorflow::errors::InvalidArgument( + "Wrong number format at line: ", line); + } + string label = pieces[0]; + auto timestamp_int64 = static_cast(timestamp); + result->push_back({label, timestamp_int64}); + } + std::sort(result->begin(), result->end(), + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); + return Status::OK(); +} + +void CalculateAccuracyStats( + const std::vector>& ground_truth_list, + const std::vector>& found_words, + int64 up_to_time_ms, int64 time_tolerance_ms, + StreamingAccuracyStats* stats) { + int64 latest_possible_time; + if (up_to_time_ms == -1) { + latest_possible_time = std::numeric_limits::max(); + } else { + latest_possible_time = up_to_time_ms + time_tolerance_ms; + } + stats->how_many_ground_truth_words = 0; + for (const std::pair& ground_truth : ground_truth_list) { + const int64 ground_truth_time = ground_truth.second; + if (ground_truth_time > latest_possible_time) { + break; + } + ++stats->how_many_ground_truth_words; + } + + stats->how_many_false_positives = 0; + stats->how_many_correct_words = 0; + stats->how_many_wrong_words = 0; + std::unordered_set has_ground_truth_been_matched; + for (const std::pair& found_word : found_words) { + const string& found_label = found_word.first; + const int64 found_time = found_word.second; + const int64 earliest_time = found_time - time_tolerance_ms; + const int64 latest_time = found_time + time_tolerance_ms; + bool has_match_been_found = false; + for (const std::pair& ground_truth : ground_truth_list) { + const int64 ground_truth_time = ground_truth.second; + if ((ground_truth_time > latest_time) || + (ground_truth_time > latest_possible_time)) { + break; + } + if (ground_truth_time < earliest_time) { + continue; + } + const string& ground_truth_label = ground_truth.first; + if ((ground_truth_label == found_label) && + (has_ground_truth_been_matched.count(ground_truth_time) == 0)) { + ++stats->how_many_correct_words; + } else { + ++stats->how_many_wrong_words; + } + has_ground_truth_been_matched.insert(ground_truth_time); + has_match_been_found = true; + break; + } + if (!has_match_been_found) { + ++stats->how_many_false_positives; + } + } + stats->how_many_ground_truth_matched = has_ground_truth_been_matched.size(); +} + +void PrintAccuracyStats(const StreamingAccuracyStats& stats) { + if (stats.how_many_ground_truth_words == 0) { + LOG(INFO) << "No ground truth yet, " << stats.how_many_false_positives + << " false positives"; + } else { + float any_match_percentage = + (stats.how_many_ground_truth_matched * 100.0f) / + stats.how_many_ground_truth_words; + float correct_match_percentage = (stats.how_many_correct_words * 100.0f) / + stats.how_many_ground_truth_words; + float 
wrong_match_percentage = (stats.how_many_wrong_words * 100.0f) / + stats.how_many_ground_truth_words; + float false_positive_percentage = + (stats.how_many_false_positives * 100.0f) / + stats.how_many_ground_truth_words; + + LOG(INFO) << std::setprecision(1) << std::fixed << any_match_percentage + << "% matched, " << correct_match_percentage << "% correctly, " + << wrong_match_percentage << "% wrongly, " + << false_positive_percentage << "% false positives "; + } +} + +} // namespace tensorflow diff --git a/tensorflow/examples/speech_commands/accuracy_utils.h b/tensorflow/examples/speech_commands/accuracy_utils.h new file mode 100644 index 00000000000..8d918cb64b0 --- /dev/null +++ b/tensorflow/examples/speech_commands/accuracy_utils.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +struct StreamingAccuracyStats { + StreamingAccuracyStats() + : how_many_ground_truth_words(0), + how_many_ground_truth_matched(0), + how_many_false_positives(0), + how_many_correct_words(0), + how_many_wrong_words(0) {} + int32 how_many_ground_truth_words; + int32 how_many_ground_truth_matched; + int32 how_many_false_positives; + int32 how_many_correct_words; + int32 how_many_wrong_words; +}; + +// Takes a file name, and loads a list of expected word labels and times from +// it, as comma-separated variables. +Status ReadGroundTruthFile(const string& file_name, + std::vector>* result); + +// Given ground truth labels and corresponding predictions found by a model, +// figure out how many were correct. Takes a time limit, so that only +// predictions up to a point in time are considered, in case we're evaluating +// accuracy when the model has only been run on part of the stream. +void CalculateAccuracyStats( + const std::vector>& ground_truth_list, + const std::vector>& found_words, + int64 up_to_time_ms, int64 time_tolerance_ms, + StreamingAccuracyStats* stats); + +// Writes a human-readable description of the statistics to stdout. +void PrintAccuracyStats(const StreamingAccuracyStats& stats); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ diff --git a/tensorflow/examples/speech_commands/accuracy_utils_test.cc b/tensorflow/examples/speech_commands/accuracy_utils_test.cc new file mode 100644 index 00000000000..47653ddf037 --- /dev/null +++ b/tensorflow/examples/speech_commands/accuracy_utils_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/speech_commands/accuracy_utils.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(AccuracyUtilsTest, ReadGroundTruthFile) { + string file_name = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), + "ground_truth.txt"); + string file_data = "a,10\nb,12\n"; + TF_ASSERT_OK(WriteStringToFile(Env::Default(), file_name, file_data)); + + std::vector> ground_truth; + TF_ASSERT_OK(ReadGroundTruthFile(file_name, &ground_truth)); + ASSERT_EQ(2, ground_truth.size()); + EXPECT_EQ("a", ground_truth[0].first); + EXPECT_EQ(10, ground_truth[0].second); + EXPECT_EQ("b", ground_truth[1].first); + EXPECT_EQ(12, ground_truth[1].second); +} + +TEST(AccuracyUtilsTest, CalculateAccuracyStats) { + StreamingAccuracyStats stats; + CalculateAccuracyStats({{"a", 1000}, {"b", 9000}}, + {{"a", 1200}, {"b", 5000}, {"a", 8700}}, 10000, 500, + &stats); + EXPECT_EQ(2, stats.how_many_ground_truth_words); + EXPECT_EQ(2, stats.how_many_ground_truth_matched); + EXPECT_EQ(1, stats.how_many_false_positives); + EXPECT_EQ(1, stats.how_many_correct_words); + EXPECT_EQ(1, stats.how_many_wrong_words); +} + +TEST(AccuracyUtilsTest, PrintAccuracyStats) { + StreamingAccuracyStats stats; + PrintAccuracyStats(stats); +} + +} // namespace tensorflow diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py new file mode 100644 index 00000000000..381f3d029e5 --- /dev/null +++ b/tensorflow/examples/speech_commands/freeze.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Converts a trained checkpoint into a frozen model for mobile inference. + +Once you've trained a model using the `train.py` script, you can use this tool +to convert it into a binary GraphDef file that can be loaded into the Android, +iOS, or Raspberry Pi example code. 
Here's an example of how to run it: + +bazel run tensorflow/examples/speech_commands/freeze -- \ +--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \ +--window_stride_ms=10 --clip_duration_ms=1000 \ +--model_architecture=conv \ +--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \ +--output_file=/tmp/my_frozen_graph.pb + +One thing to watch out for is that you need to pass in the same arguments for +`sample_rate` and other command line variables here as you did for the training +script. + +The resulting graph has an input for WAV-encoded data named 'wav_data', one for +raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data', +and the output is called 'labels_softmax'. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os.path +import sys + +import tensorflow as tf + +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +import tensorflow.examples.speech_commands.input_data as input_data +import tensorflow.examples.speech_commands.models as models +from tensorflow.python.framework import graph_util + +FLAGS = None + + +def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, + window_size_ms, window_stride_ms, + dct_coefficient_count, model_architecture): + """Creates an audio model with the nodes needed for inference. + + Uses the supplied arguments to create a model, and inserts the input and + output nodes that are needed to use the graph for inference. + + Args: + wanted_words: Comma-separated list of the words we're trying to recognize. + sample_rate: How many samples per second are in the input audio files. + clip_duration_ms: How many samples to analyze for the audio pattern. + window_size_ms: Time slice duration to estimate frequencies from. + window_stride_ms: How far apart time slices should be. + dct_coefficient_count: Number of frequency bands to analyze. + model_architecture: Name of the kind of model to generate. + """ + + words_list = input_data.prepare_words_list(wanted_words.split(',')) + model_settings = models.prepare_model_settings( + len(words_list), sample_rate, clip_duration_ms, window_size_ms, + window_stride_ms, dct_coefficient_count) + + wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data') + decoded_sample_data = contrib_audio.decode_wav( + wav_data_placeholder, + desired_channels=1, + desired_samples=model_settings['desired_samples'], + name='decoded_sample_data') + spectrogram = contrib_audio.audio_spectrogram( + decoded_sample_data.audio, + window_size=model_settings['window_size_samples'], + stride=model_settings['window_stride_samples'], + magnitude_squared=True) + fingerprint_input = contrib_audio.mfcc( + spectrogram, + decoded_sample_data.sample_rate, + dct_coefficient_count=dct_coefficient_count) + + logits = models.create_model( + fingerprint_input, model_settings, model_architecture, is_training=False) + + # Create an output to use for inference. + tf.nn.softmax(logits, name='labels_softmax') + + +def main(_): + + # Create the model and load its weights. + sess = tf.InteractiveSession() + create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate, + FLAGS.clip_duration_ms, FLAGS.window_size_ms, + FLAGS.window_stride_ms, FLAGS.dct_coefficient_count, + FLAGS.model_architecture) + models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint) + + # Turn all the variables into inline constants inside the graph and save it. 
+ frozen_graph_def = graph_util.convert_variables_to_constants( + sess, sess.graph_def, ['labels_softmax']) + tf.train.write_graph( + frozen_graph_def, + os.path.dirname(FLAGS.output_file), + os.path.basename(FLAGS.output_file), + as_text=False) + tf.logging.info('Saved frozen graph to %s', FLAGS.output_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--sample_rate', + type=int, + default=16000, + help='Expected sample rate of the wavs',) + parser.add_argument( + '--clip_duration_ms', + type=int, + default=1000, + help='Expected duration in milliseconds of the wavs',) + parser.add_argument( + '--window_size_ms', + type=float, + default=20.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--window_stride_ms', + type=float, + default=10.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--dct_coefficient_count', + type=int, + default=40, + help='How many bins to use for the MFCC fingerprint',) + parser.add_argument( + '--start_checkpoint', + type=str, + default='', + help='If specified, restore this pretrained model before any training.') + parser.add_argument( + '--model_architecture', + type=str, + default='conv', + help='What model architecture to use') + parser.add_argument( + '--wanted_words', + type=str, + default='yes,no,up,down,left,right,on,off,stop,go', + help='Words to use (others will be added to an unknown label)',) + parser.add_argument( + '--output_file', type=str, help='Where to save the frozen graph.') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py new file mode 100644 index 00000000000..3386f0f282c --- /dev/null +++ b/tensorflow/examples/speech_commands/freeze_test.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for data input for speech commands.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.examples.speech_commands import freeze +from tensorflow.python.platform import test + + +class FreezeTest(test.TestCase): + + def testCreateInferenceGraph(self): + with self.test_session() as sess: + freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 20.0, 10.0, 40, + 'conv') + self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) + self.assertIsNotNone( + sess.graph.get_tensor_by_name('decoded_sample_data:0')) + self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py new file mode 100644 index 00000000000..a69e4d72c7b --- /dev/null +++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py @@ -0,0 +1,281 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Saves out a .wav file with synthesized conversational data and labels. + +The best way to estimate the real-world performance of an audio recognition +model is by running it against a continuous stream of data, the way that it +would be used in an application. Training evaluations are only run against +discrete individual samples, so the results aren't as realistic. + +To make it easy to run evaluations against audio streams, this script uses +samples from the testing partition of the data set, mixes them in at random +positions together with background noise, and saves out the result as one long +audio file. + +Here's an example of generating a test file: + +bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav -- \ +--data_dir=/tmp/my_wavs --background_dir=/tmp/my_backgrounds \ +--background_volume=0.1 --test_duration_seconds=600 \ +--output_audio_file=/tmp/streaming_test.wav \ +--output_labels_file=/tmp/streaming_test_labels.txt + +Once you've created a streaming audio file, you can then use the +test_streaming_accuracy tool to calculate accuracy metrics for a model. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import math +import sys + +import numpy as np +import tensorflow as tf + +import tensorflow.examples.speech_commands.input_data as input_data +import tensorflow.examples.speech_commands.models as models + +FLAGS = None + + +def mix_in_audio_sample(track_data, track_offset, sample_data, sample_offset, + clip_duration, sample_volume, ramp_in, ramp_out): + """Mixes the sample data into the main track at the specified offset. 
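+
+  The sample is shaped by a trapezoidal envelope: it fades in linearly over
+  `ramp_in` samples, plays at `sample_volume`, and fades out linearly over the
+  final `ramp_out` samples of the clip. A minimal illustration, mirroring the
+  unit test in generate_streaming_test_wav_test.py:
+
+    track = np.zeros([10000])
+    sample = np.ones([1000])
+    mix_in_audio_sample(track, 2000, sample, 0, 1000, 1.0, 100, 100)
+    # track[2500] is now close to 1.0, while track[3500] is still 0.0.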
+ + Args: + track_data: Numpy array holding main audio data. Modified in-place. + track_offset: Where to mix the sample into the main track. + sample_data: Numpy array of audio data to mix into the main track. + sample_offset: Where to start in the audio sample. + clip_duration: How long the sample segment is. + sample_volume: Loudness to mix the sample in at. + ramp_in: Length in samples of volume increase stage. + ramp_out: Length in samples of volume decrease stage. + """ + ramp_out_index = clip_duration - ramp_out + track_end = min(track_offset + clip_duration, track_data.shape[0]) + track_end = min(track_end, + track_offset + (sample_data.shape[0] - sample_offset)) + sample_range = track_end - track_offset + for i in range(sample_range): + if i < ramp_in: + envelope_scale = i / ramp_in + elif i > ramp_out_index: + envelope_scale = (clip_duration - i) / ramp_out + else: + envelope_scale = 1 + sample_input = sample_data[sample_offset + i] + track_data[track_offset + + i] += sample_input * envelope_scale * sample_volume + + +def main(_): + audio_processor = input_data.AudioProcessor( + '', FLAGS.data_dir, FLAGS.silence_percentage, 10, + FLAGS.wanted_words.split(','), FLAGS.validation_percentage, + FLAGS.testing_percentage) + words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(',')) + model_settings = models.prepare_model_settings( + len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms, + FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count) + + output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds + output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32) + + # Set up background audio. + background_crossover_ms = 500 + background_segment_duration_ms = ( + FLAGS.clip_duration_ms + background_crossover_ms) + background_segment_duration_samples = int( + (background_segment_duration_ms * FLAGS.sample_rate) / 1000) + background_segment_stride_samples = int( + (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000) + background_ramp_samples = int( + ((background_crossover_ms / 2) * FLAGS.sample_rate) / 1000) + + # Mix the background audio into the main track. + how_many_backgrounds = int( + math.ceil(output_audio_sample_count / background_segment_stride_samples)) + for i in range(how_many_backgrounds): + output_offset = int(i * background_segment_stride_samples) + background_index = np.random.randint(len(audio_processor.background_data)) + background_samples = audio_processor.background_data[background_index] + background_offset = np.random.randint( + 0, len(background_samples) - model_settings['desired_samples']) + background_volume = np.random.uniform(0, FLAGS.background_volume) + mix_in_audio_sample(output_audio, output_offset, background_samples, + background_offset, background_segment_duration_samples, + background_volume, background_ramp_samples, + background_ramp_samples) + + # Mix the words into the main track, noting their labels and positions. 
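+  # Each word clip lands on a grid spaced word_stride_samples apart (the clip
+  # duration plus the average gap), with a random jitter of up to
+  # word_gap_samples so words don't arrive at perfectly regular intervals.
+  # Roughly unknown_percentage of the slots are drawn from the unknown
+  # classes, the rest from the wanted words list.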
+ output_labels = [] + word_stride_ms = FLAGS.clip_duration_ms + FLAGS.word_gap_ms + word_stride_samples = int((word_stride_ms * FLAGS.sample_rate) / 1000) + clip_duration_samples = int( + (FLAGS.clip_duration_ms * FLAGS.sample_rate) / 1000) + word_gap_samples = int((FLAGS.word_gap_ms * FLAGS.sample_rate) / 1000) + how_many_words = int( + math.floor(output_audio_sample_count / word_stride_samples)) + all_test_data, all_test_labels = audio_processor.get_unprocessed_data( + -1, model_settings, 'testing') + for i in range(how_many_words): + output_offset = ( + int(i * word_stride_samples) + np.random.randint(word_gap_samples)) + output_offset_ms = (output_offset * 1000) / FLAGS.sample_rate + is_unknown = np.random.randint(100) < FLAGS.unknown_percentage + if is_unknown: + wanted_label = input_data.UNKNOWN_WORD_LABEL + else: + wanted_label = words_list[2 + np.random.randint(len(words_list) - 2)] + test_data_start = np.random.randint(len(all_test_data)) + found_sample_data = None + index_lookup = np.arange(len(all_test_data), dtype=np.int32) + np.random.shuffle(index_lookup) + for test_data_offset in range(len(all_test_data)): + test_data_index = index_lookup[( + test_data_start + test_data_offset) % len(all_test_data)] + current_label = all_test_labels[test_data_index] + if current_label == wanted_label: + found_sample_data = all_test_data[test_data_index] + break + mix_in_audio_sample(output_audio, output_offset, found_sample_data, 0, + clip_duration_samples, 1.0, 500, 500) + output_labels.append({'label': wanted_label, 'time': output_offset_ms}) + + input_data.save_wav_file(FLAGS.output_audio_file, output_audio, + FLAGS.sample_rate) + tf.logging.info('Saved streaming test wav to %s', FLAGS.output_audio_file) + + with open(FLAGS.output_labels_file, 'w') as f: + for output_label in output_labels: + f.write('%s, %f\n' % (output_label['label'], output_label['time'])) + tf.logging.info('Saved streaming test labels to %s', FLAGS.output_labels_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--data_url', + type=str, + # pylint: disable=line-too-long + default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz', + # pylint: enable=line-too-long + help='Location of speech training data') + parser.add_argument( + '--data_dir', + type=str, + default='/tmp/speech_dataset', + help="""\ + Where to download the speech training data to. + """) + parser.add_argument( + '--background_dir', + type=str, + default='', + help="""\ + Path to a directory of .wav files to mix in as background noise during training. + """) + parser.add_argument( + '--background_volume', + type=float, + default=0.1, + help="""\ + How loud the background noise should be, between 0 and 1. + """) + parser.add_argument( + '--background_frequency', + type=float, + default=0.8, + help="""\ + How many of the training samples have background noise mixed in. + """) + parser.add_argument( + '--silence_percentage', + type=float, + default=10.0, + help="""\ + How much of the training data should be silence. 
+ """) + parser.add_argument( + '--testing_percentage', + type=int, + default=10, + help='What percentage of wavs to use as a test set.') + parser.add_argument( + '--validation_percentage', + type=int, + default=10, + help='What percentage of wavs to use as a validation set.') + parser.add_argument( + '--sample_rate', + type=int, + default=16000, + help='Expected sample rate of the wavs.',) + parser.add_argument( + '--clip_duration_ms', + type=int, + default=1000, + help='Expected duration in milliseconds of the wavs.',) + parser.add_argument( + '--window_size_ms', + type=float, + default=20.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--window_stride_ms', + type=float, + default=10.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--dct_coefficient_count', + type=int, + default=40, + help='How many bins to use for the MFCC fingerprint',) + parser.add_argument( + '--wanted_words', + type=str, + default='yes,no,up,down,left,right,on,off,stop,go', + help='Words to use (others will be added to an unknown label)',) + parser.add_argument( + '--output_audio_file', + type=str, + default='/tmp/speech_commands_train/streaming_test.wav', + help='File to save the generated test audio to.') + parser.add_argument( + '--output_labels_file', + type=str, + default='/tmp/speech_commands_train/streaming_test_labels.txt', + help='File to save the generated test labels to.') + parser.add_argument( + '--test_duration_seconds', + type=int, + default=600, + help='How long the generated test audio file should be.',) + parser.add_argument( + '--word_gap_ms', + type=int, + default=2000, + help='How long the average gap should be between words.',) + parser.add_argument( + '--unknown_percentage', + type=int, + default=30, + help='What percentage of words should be unknown.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav_test.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav_test.py new file mode 100644 index 00000000000..63d93f4534c --- /dev/null +++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for test file generation for speech commands.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.examples.speech_commands import generate_streaming_test_wav +from tensorflow.python.platform import test + + +class GenerateStreamingTestWavTest(test.TestCase): + + def testMixInAudioSample(self): + track_data = np.zeros([10000]) + sample_data = np.ones([1000]) + generate_streaming_test_wav.mix_in_audio_sample( + track_data, 2000, sample_data, 0, 1000, 1.0, 100, 100) + self.assertNear(1.0, track_data[2500], 0.0001) + self.assertNear(0.0, track_data[3500], 0.0001) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py new file mode 100644 index 00000000000..6d75fbb92b2 --- /dev/null +++ b/tensorflow/examples/speech_commands/input_data.py @@ -0,0 +1,532 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model definitions for simple speech recognition. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import math +import os.path +import random +import re +import sys +import tarfile + +import numpy as np +from six.moves import urllib +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +from tensorflow.python.ops import io_ops +from tensorflow.python.platform import gfile +from tensorflow.python.util import compat + +MAX_NUM_WAVS_PER_CLASS = 2**27 - 1 # ~134M +SILENCE_LABEL = '_silence_' +SILENCE_INDEX = 0 +UNKNOWN_WORD_LABEL = '_unknown_' +UNKNOWN_WORD_INDEX = 1 +BACKGROUND_NOISE_DIR_NAME = '_background_noise_' +RANDOM_SEED = 59185 + + +def prepare_words_list(wanted_words): + """Prepends common tokens to the custom word list. + + Args: + wanted_words: List of strings containing the custom words. + + Returns: + List with the standard silence and unknown tokens added. + """ + return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words + + +def which_set(filename, validation_percentage, testing_percentage): + """Determines which data partition the file should belong to. + + We want to keep files in the same training, validation, or testing sets even + if new ones are added over time. This makes it less likely that testing + samples will accidentally be reused in training when long runs are restarted + for example. To keep this stability, a hash of the filename is taken and used + to determine which set it should belong to. This determination only depends on + the name and the set proportions, so it won't change as other files are added. 
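+
+  Concretely, the name (minus any '_nohash_' suffix) is hashed with SHA-1 and
+  mapped to a percentage between 0 and 100, which is then compared against the
+  validation and testing thresholds to pick the partition.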
+ + It's also useful to associate particular files as related (for example words + spoken by the same person), so anything after '_nohash_' in a filename is + ignored for set determination. This ensures that 'bobby_nohash_0.wav' and + 'bobby_nohash_1.wav' are always in the same set, for example. + + Args: + filename: File path of the data sample. + validation_percentage: How much of the data set to use for validation. + testing_percentage: How much of the data set to use for testing. + + Returns: + String, one of 'training', 'validation', or 'testing'. + """ + base_name = os.path.basename(filename) + # We want to ignore anything after '_nohash_' in the file name when + # deciding which set to put a wav in, so the data set creator has a way of + # grouping wavs that are close variations of each other. + hash_name = re.sub(r'_nohash_.*$', '', base_name) + # This looks a bit magical, but we need to decide whether this file should + # go into the training, testing, or validation sets, and we want to keep + # existing files in the same set even if more files are subsequently + # added. + # To do that, we need a stable way of deciding based on just the file name + # itself, so we do a hash of that and then use that to generate a + # probability value that we use to assign it. + hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() + percentage_hash = ((int(hash_name_hashed, 16) % + (MAX_NUM_WAVS_PER_CLASS + 1)) * + (100.0 / MAX_NUM_WAVS_PER_CLASS)) + if percentage_hash < validation_percentage: + result = 'validation' + elif percentage_hash < (testing_percentage + validation_percentage): + result = 'testing' + else: + result = 'training' + return result + + +def load_wav_file(filename): + """Loads an audio file and returns a float PCM-encoded array of samples. + + Args: + filename: Path to the .wav file to load. + + Returns: + Numpy array holding the sample data as floats between -1.0 and 1.0. + """ + with tf.Session(graph=tf.Graph()) as sess: + wav_filename_placeholder = tf.placeholder(tf.string, []) + wav_loader = io_ops.read_file(wav_filename_placeholder) + wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1) + return sess.run( + wav_decoder, + feed_dict={wav_filename_placeholder: filename}).audio.flatten() + + +def save_wav_file(filename, wav_data, sample_rate): + """Saves audio sample data to a .wav audio file. + + Args: + filename: Path to save the file to. + wav_data: 2D array of float PCM-encoded audio data. + sample_rate: Samples per second to encode in the file. 
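+
+  Example (writes one second of silence at the 16kHz rate used elsewhere in
+  this tutorial; the output path is just an illustration):
+
+    save_wav_file('/tmp/silence.wav', np.zeros([16000, 1]), 16000)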
+ """ + with tf.Session(graph=tf.Graph()) as sess: + wav_filename_placeholder = tf.placeholder(tf.string, []) + sample_rate_placeholder = tf.placeholder(tf.int32, []) + wav_data_placeholder = tf.placeholder(tf.float32, [None, 1]) + wav_encoder = contrib_audio.encode_wav(wav_data_placeholder, + sample_rate_placeholder) + wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder) + sess.run( + wav_saver, + feed_dict={ + wav_filename_placeholder: filename, + sample_rate_placeholder: sample_rate, + wav_data_placeholder: np.reshape(wav_data, (-1, 1)) + }) + + +class AudioProcessor(object): + """Handles loading, partitioning, and preparing audio training data.""" + + def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage, + wanted_words, validation_percentage, testing_percentage, + model_settings): + self.data_dir = data_dir + self.maybe_download_and_extract_dataset(data_url, data_dir) + self.prepare_data_index(silence_percentage, unknown_percentage, + wanted_words, validation_percentage, + testing_percentage) + self.prepare_background_data() + self.prepare_processing_graph(model_settings) + + def maybe_download_and_extract_dataset(self, data_url, dest_directory): + """Download and extract data set tar file. + + If the data set we're using doesn't already exist, this function + downloads it from the TensorFlow.org website and unpacks it into a + directory. + If the data_url is none, don't download anything and expect the data + directory to contain the correct files already. + + Args: + data_url: Web location of the tar file containing the data set. + dest_directory: File path to extract data to. + """ + if not data_url: + return + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = data_url.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + + def _progress(count, block_size, total_size): + sys.stdout.write( + '\r>> Downloading %s %.1f%%' % + (filename, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + + try: + filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) + except: + tf.logging.error('Failed to download URL: %s to folder: %s', data_url, + filepath) + tf.logging.error('Please make sure you have enough free space and' + ' an internet connection') + raise + print() + statinfo = os.stat(filepath) + tf.logging.info('Successfully downloaded %s (%d bytes)', filename, + statinfo.st_size) + tarfile.open(filepath, 'r:gz').extractall(dest_directory) + + def prepare_data_index(self, silence_percentage, unknown_percentage, + wanted_words, validation_percentage, + testing_percentage): + """Prepares a list of the samples organized by set and label. + + The training loop needs a list of all the available data, organized by + which partition it should belong to, and with ground truth labels attached. + This function analyzes the folders below the `data_dir`, figures out the + right + labels for each file based on the name of the subdirectory it belongs to, + and uses a stable hash to assign it to a data set partition. + + Args: + silence_percentage: How much of the resulting data should be background. + unknown_percentage: How much should be audio outside the wanted classes. + wanted_words: Labels of the classes we want to be able to recognize. + validation_percentage: How much of the data set to use for validation. + testing_percentage: How much of the data set to use for testing. 
+ + Returns: + Dictionary containing a list of file information for each set partition, + and a lookup map for each class to determine its numeric index. + + Raises: + Exception: If expected files are not found. + """ + # Make sure the shuffling and picking of unknowns is deterministic. + random.seed(RANDOM_SEED) + wanted_words_index = {} + for index, wanted_word in enumerate(wanted_words): + wanted_words_index[wanted_word] = index + 2 + self.data_index = {'validation': [], 'testing': [], 'training': []} + unknown_index = {'validation': [], 'testing': [], 'training': []} + all_words = {} + # Look through all the subfolders to find audio samples + search_path = os.path.join(self.data_dir, '*', '*.wav') + for wav_path in gfile.Glob(search_path): + word = re.search('.*/([^/]+)/.*.wav', wav_path).group(1).lower() + # Treat the '_background_noise_' folder as a special case, since we expect + # it to contain long audio samples we mix in to improve training. + if word == BACKGROUND_NOISE_DIR_NAME: + continue + all_words[word] = True + set_index = which_set(wav_path, validation_percentage, testing_percentage) + # If it's a known class, store its detail, otherwise add it to the list + # we'll use to train the unknown label. + if word in wanted_words_index: + self.data_index[set_index].append({'label': word, 'file': wav_path}) + else: + unknown_index[set_index].append({'label': word, 'file': wav_path}) + if not all_words: + raise Exception('No .wavs found at ' + search_path) + for index, wanted_word in enumerate(wanted_words): + if wanted_word not in all_words: + raise Exception('Expected to find ' + wanted_word + + ' in labels but only found ' + + ', '.join(all_words.keys())) + # We need an arbitrary file to load as the input for the silence samples. + # It's multiplied by zero later, so the content doesn't matter. + silence_wav_path = self.data_index['training'][0]['file'] + for set_index in ['validation', 'testing', 'training']: + set_size = len(self.data_index[set_index]) + silence_size = int(math.ceil(set_size * silence_percentage / 100)) + for _ in range(silence_size): + self.data_index[set_index].append({ + 'label': SILENCE_LABEL, + 'file': silence_wav_path + }) + # Pick some unknowns to add to each partition of the data set. + random.shuffle(unknown_index[set_index]) + unknown_size = int(math.ceil(set_size * unknown_percentage / 100)) + self.data_index[set_index].extend(unknown_index[set_index][:unknown_size]) + # Make sure the ordering is random. + for set_index in ['validation', 'testing', 'training']: + random.shuffle(self.data_index[set_index]) + # Prepare the rest of the result data structure. + self.words_list = prepare_words_list(wanted_words) + self.word_to_index = {} + for word in all_words: + if word in wanted_words_index: + self.word_to_index[word] = wanted_words_index[word] + else: + self.word_to_index[word] = UNKNOWN_WORD_INDEX + self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX + + def prepare_background_data(self): + """Searches a folder for background noise audio, and loads it into memory. + + It's expected that the background audio samples will be in a subdirectory + named '_background_noise_' inside the 'data_dir' folder, as .wavs that match + the sample rate of the training data, but can be much longer in duration. + + If the '_background_noise_' folder doesn't exist at all, this isn't an + error, it's just taken to mean that no background noise augmentation should + be used. If the folder does exist, but it's empty, that's treated as an + error. 
+ + Returns: + List of raw PCM-encoded audio samples of background noise. + + Raises: + Exception: If files aren't found in the folder. + """ + self.background_data = [] + background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME) + if not os.path.exists(background_dir): + return self.background_data + with tf.Session(graph=tf.Graph()) as sess: + wav_filename_placeholder = tf.placeholder(tf.string, []) + wav_loader = io_ops.read_file(wav_filename_placeholder) + wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1) + search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME, + '*.wav') + for wav_path in gfile.Glob(search_path): + wav_data = sess.run( + wav_decoder, + feed_dict={wav_filename_placeholder: wav_path}).audio.flatten() + self.background_data.append(wav_data) + if not self.background_data: + raise Exception('No background wav files were found in ' + search_path) + + def prepare_processing_graph(self, model_settings): + """Builds a TensorFlow graph to apply the input distortions. + + Creates a graph that loads a WAVE file, decodes it, scales the volume, + shifts it in time, adds in background noise, calculates a spectrogram, and + then builds an MFCC fingerprint from that. + + This must be called with an active TensorFlow session running, and it + creates multiple placeholder inputs, and one output: + + - wav_filename_placeholder_: Filename of the WAV to load. + - foreground_volume_placeholder_: How loud the main clip should be. + - time_shift_padding_placeholder_: Where to pad the clip. + - time_shift_offset_placeholder_: How much to move the clip in time. + - background_data_placeholder_: PCM sample data for background noise. + - background_volume_placeholder_: Loudness of mixed-in background. + - mfcc_: Output 2D fingerprint of processed audio. + + Args: + model_settings: Information about the current model being trained. + """ + desired_samples = model_settings['desired_samples'] + self.wav_filename_placeholder_ = tf.placeholder(tf.string, []) + wav_loader = io_ops.read_file(self.wav_filename_placeholder_) + wav_decoder = contrib_audio.decode_wav( + wav_loader, desired_channels=1, desired_samples=desired_samples) + # Allow the audio sample's volume to be adjusted. + self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, []) + scaled_foreground = tf.multiply(wav_decoder.audio, + self.foreground_volume_placeholder_) + # Shift the sample's start position, and pad any gaps with zeros. + self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2]) + self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2]) + padded_foreground = tf.pad( + scaled_foreground, + self.time_shift_padding_placeholder_, + mode='CONSTANT') + sliced_foreground = tf.slice(padded_foreground, + self.time_shift_offset_placeholder_, + [desired_samples, -1]) + # Mix in background noise. + self.background_data_placeholder_ = tf.placeholder(tf.float32, + [desired_samples, 1]) + self.background_volume_placeholder_ = tf.placeholder(tf.float32, []) + background_mul = tf.multiply(self.background_data_placeholder_, + self.background_volume_placeholder_) + background_add = tf.add(background_mul, sliced_foreground) + background_clamp = tf.clip_by_value(background_add, -1.0, 1.0) + # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio. 
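+    # audio_spectrogram should emit one row of frequencies per window position
+    # (roughly [channels, spectrogram_length, fft_bins]), and mfcc then
+    # compresses each row down to dct_coefficient_count values, so mfcc_ is
+    # expected to come out as [1, spectrogram_length, dct_coefficient_count]
+    # for the single-channel clips used here.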
+ spectrogram = contrib_audio.audio_spectrogram( + background_clamp, + window_size=model_settings['window_size_samples'], + stride=model_settings['window_stride_samples'], + magnitude_squared=True) + self.mfcc_ = contrib_audio.mfcc( + spectrogram, + wav_decoder.sample_rate, + dct_coefficient_count=model_settings['dct_coefficient_count']) + + def set_size(self, mode): + """Calculates the number of samples in the dataset partition. + + Args: + mode: Which partition, must be 'training', 'validation', or 'testing'. + + Returns: + Number of samples in the partition. + """ + return len(self.data_index[mode]) + + def get_data(self, how_many, offset, model_settings, background_frequency, + background_volume_range, time_shift, mode, sess): + """Gather samples from the data set, applying transformations as needed. + + When the mode is 'training', a random selection of samples will be returned, + otherwise the first N clips in the partition will be used. This ensures that + validation always uses the same samples, reducing noise in the metrics. + + Args: + how_many: Desired number of samples to return. -1 means the entire + contents of this partition. + offset: Where to start when fetching deterministically. + model_settings: Information about the current model being trained. + background_frequency: How many clips will have background noise, 0.0 to + 1.0. + background_volume_range: How loud the background noise will be. + time_shift: How much to randomly shift the clips by in time. + mode: Which partition to use, must be 'training', 'validation', or + 'testing'. + sess: TensorFlow session that was active when processor was created. + + Returns: + List of sample data for the transformed samples, and list of labels in + one-hot form. + """ + # Pick one of the partitions to choose samples from. + candidates = self.data_index[mode] + if how_many == -1: + sample_count = len(candidates) + else: + sample_count = max(0, min(how_many, len(candidates) - offset)) + # Data and labels will be populated and returned. + data = np.zeros((sample_count, model_settings['fingerprint_size'])) + labels = np.zeros((sample_count, model_settings['label_count'])) + desired_samples = model_settings['desired_samples'] + use_background = self.background_data and (mode == 'training') + pick_deterministically = (mode != 'training') + # Use the processing graph we created earlier to repeatedly to generate the + # final output sample data we'll use in training. + for i in xrange(offset, offset + sample_count): + # Pick which audio sample to use. + if how_many == -1 or pick_deterministically: + sample_index = i + else: + sample_index = np.random.randint(len(candidates)) + sample = candidates[sample_index] + # If we're time shifting, set up the offset for this sample. + if time_shift > 0: + time_shift_amount = np.random.randint(-time_shift, time_shift) + else: + time_shift_amount = 0 + if time_shift_amount > 0: + time_shift_padding = [[time_shift_amount, 0], [0, 0]] + time_shift_offset = [0, 0] + else: + time_shift_padding = [[0, -time_shift_amount], [0, 0]] + time_shift_offset = [-time_shift_amount, 0] + input_dict = { + self.wav_filename_placeholder_: sample['file'], + self.time_shift_padding_placeholder_: time_shift_padding, + self.time_shift_offset_placeholder_: time_shift_offset, + } + # Choose a section of background noise to mix in. 
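+      # Background is only mixed into training samples. A random
+      # desired_samples-long slice is cut from one of the loaded background
+      # wavs, and with probability background_frequency it's added at a volume
+      # drawn uniformly from [0, background_volume_range); otherwise its
+      # volume is left at zero.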
+ if use_background: + background_index = np.random.randint(len(self.background_data)) + background_samples = self.background_data[background_index] + background_offset = np.random.randint( + 0, len(background_samples) - model_settings['desired_samples']) + background_clipped = background_samples[background_offset:( + background_offset + desired_samples)] + background_reshaped = background_clipped.reshape([desired_samples, 1]) + if np.random.uniform(0, 1) < background_frequency: + background_volume = np.random.uniform(0, background_volume_range) + else: + background_volume = 0 + else: + background_reshaped = np.zeros([desired_samples, 1]) + background_volume = 0 + input_dict[self.background_data_placeholder_] = background_reshaped + input_dict[self.background_volume_placeholder_] = background_volume + # If we want silence, mute out the main sample but leave the background. + if sample['label'] == SILENCE_LABEL: + input_dict[self.foreground_volume_placeholder_] = 0 + else: + input_dict[self.foreground_volume_placeholder_] = 1 + # Run the graph to produce the output audio. + data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten() + label_index = self.word_to_index[sample['label']] + labels[i - offset, label_index] = 1 + return data, labels + + def get_unprocessed_data(self, how_many, model_settings, mode): + """Retrieve sample data for the given partition, with no transformations. + + Args: + how_many: Desired number of samples to return. -1 means the entire + contents of this partition. + model_settings: Information about the current model being trained. + mode: Which partition to use, must be 'training', 'validation', or + 'testing'. + + Returns: + List of sample data for the samples, and list of labels in one-hot form. + """ + candidates = self.data_index[mode] + if how_many == -1: + sample_count = len(candidates) + else: + sample_count = how_many + desired_samples = model_settings['desired_samples'] + words_list = self.words_list + data = np.zeros((sample_count, desired_samples)) + labels = [] + with tf.Session(graph=tf.Graph()) as sess: + wav_filename_placeholder = tf.placeholder(tf.string, []) + wav_loader = io_ops.read_file(wav_filename_placeholder) + wav_decoder = contrib_audio.decode_wav( + wav_loader, desired_channels=1, desired_samples=desired_samples) + foreground_volume_placeholder = tf.placeholder(tf.float32, []) + scaled_foreground = tf.multiply(wav_decoder.audio, + foreground_volume_placeholder) + for i in range(sample_count): + if how_many == -1: + sample_index = i + else: + sample_index = np.random.randint(len(candidates)) + sample = candidates[sample_index] + input_dict = {wav_filename_placeholder: sample['file']} + if sample['label'] == SILENCE_LABEL: + input_dict[foreground_volume_placeholder] = 0 + else: + input_dict[foreground_volume_placeholder] = 1 + data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten() + label_index = self.word_to_index[sample['label']] + labels.append(words_list[label_index]) + return data, labels diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py new file mode 100644 index 00000000000..13f294d39db --- /dev/null +++ b/tensorflow/examples/speech_commands/input_data_test.py @@ -0,0 +1,212 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for data input for speech commands.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +from tensorflow.examples.speech_commands import input_data +from tensorflow.python.platform import test + + +class InputDataTest(test.TestCase): + + def _getWavData(self): + with self.test_session() as sess: + sample_data = tf.zeros([1000, 2]) + wav_encoder = contrib_audio.encode_wav(sample_data, 16000) + wav_data = sess.run(wav_encoder) + return wav_data + + def _saveTestWavFile(self, filename, wav_data): + with open(filename, "wb") as f: + f.write(wav_data) + + def _saveWavFolders(self, root_dir, labels, how_many): + wav_data = self._getWavData() + for label in labels: + dir_name = os.path.join(root_dir, label) + os.mkdir(dir_name) + for i in range(how_many): + file_path = os.path.join(dir_name, "some_audio_%d.wav" % i) + self._saveTestWavFile(file_path, wav_data) + + def _model_settings(self): + return { + "desired_samples": 160, + "fingerprint_size": 40, + "label_count": 4, + "window_size_samples": 100, + "window_stride_samples": 100, + "dct_coefficient_count": 40, + } + + def testPrepareWordsList(self): + words_list = ["a", "b"] + self.assertGreater( + len(input_data.prepare_words_list(words_list)), len(words_list)) + + def testWhichSet(self): + self.assertEqual( + input_data.which_set("foo.wav", 10, 10), + input_data.which_set("foo.wav", 10, 10)) + self.assertEqual( + input_data.which_set("foo_nohash_0.wav", 10, 10), + input_data.which_set("foo_nohash_1.wav", 10, 10)) + + def testPrepareDataIndex(self): + tmp_dir = self.get_temp_dir() + self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) + audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], + 10, 10, self._model_settings()) + self.assertLess(0, audio_processor.set_size("training")) + self.assertTrue("training" in audio_processor.data_index) + self.assertTrue("validation" in audio_processor.data_index) + self.assertTrue("testing" in audio_processor.data_index) + self.assertEquals(input_data.UNKNOWN_WORD_INDEX, + audio_processor.word_to_index["c"]) + + def testPrepareDataIndexEmpty(self): + tmp_dir = self.get_temp_dir() + self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0) + with self.assertRaises(Exception) as e: + _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10, + self._model_settings()) + self.assertTrue("No .wavs found" in str(e.exception)) + + def testPrepareDataIndexMissing(self): + tmp_dir = self.get_temp_dir() + self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) + with self.assertRaises(Exception) as e: + _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10, + 10, self._model_settings()) + self.assertTrue("Expected to find" in str(e.exception)) + + def testPrepareBackgroundData(self): + tmp_dir = self.get_temp_dir() + background_dir = os.path.join(tmp_dir, 
"_background_noise_") + os.mkdir(background_dir) + wav_data = self._getWavData() + for i in range(10): + file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) + self._saveTestWavFile(file_path, wav_data) + self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) + audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], + 10, 10, self._model_settings()) + self.assertEqual(10, len(audio_processor.background_data)) + + def testLoadWavFile(self): + tmp_dir = self.get_temp_dir() + file_path = os.path.join(tmp_dir, "load_test.wav") + wav_data = self._getWavData() + self._saveTestWavFile(file_path, wav_data) + sample_data = input_data.load_wav_file(file_path) + self.assertIsNotNone(sample_data) + + def testSaveWavFile(self): + tmp_dir = self.get_temp_dir() + file_path = os.path.join(tmp_dir, "load_test.wav") + save_data = np.zeros([16000, 1]) + input_data.save_wav_file(file_path, save_data, 16000) + loaded_data = input_data.load_wav_file(file_path) + self.assertIsNotNone(loaded_data) + self.assertEqual(16000, len(loaded_data)) + + def testPrepareProcessingGraph(self): + tmp_dir = self.get_temp_dir() + wav_dir = os.path.join(tmp_dir, "wavs") + os.mkdir(wav_dir) + self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) + background_dir = os.path.join(wav_dir, "_background_noise_") + os.mkdir(background_dir) + wav_data = self._getWavData() + for i in range(10): + file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) + self._saveTestWavFile(file_path, wav_data) + model_settings = { + "desired_samples": 160, + "fingerprint_size": 40, + "label_count": 4, + "window_size_samples": 100, + "window_stride_samples": 100, + "dct_coefficient_count": 40, + } + audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], + 10, 10, model_settings) + self.assertIsNotNone(audio_processor.wav_filename_placeholder_) + self.assertIsNotNone(audio_processor.foreground_volume_placeholder_) + self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_) + self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_) + self.assertIsNotNone(audio_processor.background_data_placeholder_) + self.assertIsNotNone(audio_processor.background_volume_placeholder_) + self.assertIsNotNone(audio_processor.mfcc_) + + def testGetData(self): + tmp_dir = self.get_temp_dir() + wav_dir = os.path.join(tmp_dir, "wavs") + os.mkdir(wav_dir) + self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) + background_dir = os.path.join(wav_dir, "_background_noise_") + os.mkdir(background_dir) + wav_data = self._getWavData() + for i in range(10): + file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) + self._saveTestWavFile(file_path, wav_data) + model_settings = { + "desired_samples": 160, + "fingerprint_size": 40, + "label_count": 4, + "window_size_samples": 100, + "window_stride_samples": 100, + "dct_coefficient_count": 40, + } + audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], + 10, 10, model_settings) + with self.test_session() as sess: + result_data, result_labels = audio_processor.get_data( + 10, 0, model_settings, 0.3, 0.1, 100, "training", sess) + self.assertEqual(10, len(result_data)) + self.assertEqual(10, len(result_labels)) + + def testGetUnprocessedData(self): + tmp_dir = self.get_temp_dir() + wav_dir = os.path.join(tmp_dir, "wavs") + os.mkdir(wav_dir) + self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) + model_settings = { + "desired_samples": 160, + "fingerprint_size": 40, + "label_count": 4, + 
"window_size_samples": 100, + "window_stride_samples": 100, + "dct_coefficient_count": 40, + } + audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], + 10, 10, model_settings) + result_data, result_labels = audio_processor.get_unprocessed_data( + 10, model_settings, "training") + self.assertEqual(10, len(result_data)) + self.assertEqual(10, len(result_labels)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/examples/speech_commands/label_wav.cc b/tensorflow/examples/speech_commands/label_wav.cc new file mode 100644 index 00000000000..d8267388317 --- /dev/null +++ b/tensorflow/examples/speech_commands/label_wav.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +// These are all common classes it's handy to reference with no namespace. +using tensorflow::Flag; +using tensorflow::Status; +using tensorflow::Tensor; +using tensorflow::int32; +using tensorflow::string; + +namespace { + +// Reads a model graph definition from disk, and creates a session object you +// can use to run it. +Status LoadGraph(const string& graph_file_name, + std::unique_ptr* session) { + tensorflow::GraphDef graph_def; + Status load_graph_status = + ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); + if (!load_graph_status.ok()) { + return tensorflow::errors::NotFound("Failed to load compute graph at '", + graph_file_name, "'"); + } + session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); + Status session_create_status = (*session)->Create(graph_def); + if (!session_create_status.ok()) { + return session_create_status; + } + return Status::OK(); +} + +// Takes a file name, and loads a list of labels from it, one per line, and +// returns a vector of the strings. +Status ReadLabelsFile(const string& file_name, std::vector* result) { + std::ifstream file(file_name); + if (!file) { + return tensorflow::errors::NotFound("Labels file ", file_name, + " not found."); + } + result->clear(); + string line; + while (std::getline(file, line)) { + result->push_back(line); + } + return Status::OK(); +} + +// Analyzes the output of the graph to retrieve the highest scores and +// their positions in the tensor. 
+void GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
+                  Tensor* out_indices, Tensor* out_scores) {
+  const Tensor& unsorted_scores_tensor = outputs[0];
+  auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
+  std::vector<std::pair<int, float>> scores;
+  scores.reserve(unsorted_scores_flat.size());
+  for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
+    scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
+  }
+  std::sort(scores.begin(), scores.end(),
+            [](const std::pair<int, float>& left,
+               const std::pair<int, float>& right) {
+              return left.second > right.second;
+            });
+  scores.resize(how_many_labels);
+  Tensor sorted_indices(tensorflow::DT_INT32, {how_many_labels});
+  Tensor sorted_scores(tensorflow::DT_FLOAT, {how_many_labels});
+  for (int i = 0; i < scores.size(); ++i) {
+    sorted_indices.flat<int>()(i) = scores[i].first;
+    sorted_scores.flat<float>()(i) = scores[i].second;
+  }
+  *out_indices = sorted_indices;
+  *out_scores = sorted_scores;
+}
+
+}  // namespace
+
+int main(int argc, char* argv[]) {
+  string wav = "";
+  string graph = "";
+  string labels = "";
+  string input_name = "wav_data";
+  string output_name = "labels_softmax";
+  int32 how_many_labels = 3;
+  std::vector<Flag> flag_list = {
+      Flag("wav", &wav, "audio file to be identified"),
+      Flag("graph", &graph, "model to be executed"),
+      Flag("labels", &labels, "path to file containing labels"),
+      Flag("input_name", &input_name, "name of input node in model"),
+      Flag("output_name", &output_name, "name of output node in model"),
+      Flag("how_many_labels", &how_many_labels, "number of results to show"),
+  };
+  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  if (!parse_result) {
+    LOG(ERROR) << usage;
+    return -1;
+  }
+
+  // We need to call this to set up global state for TensorFlow.
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  if (argc > 1) {
+    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+    return -1;
+  }
+
+  // First we load and initialize the model.
+  std::unique_ptr<tensorflow::Session> session;
+  Status load_graph_status = LoadGraph(graph, &session);
+  if (!load_graph_status.ok()) {
+    LOG(ERROR) << load_graph_status;
+    return -1;
+  }
+
+  std::vector<string> labels_list;
+  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
+  if (!read_labels_status.ok()) {
+    LOG(ERROR) << read_labels_status;
+    return -1;
+  }
+
+  string wav_string;
+  Status read_wav_status = tensorflow::ReadFileToString(
+      tensorflow::Env::Default(), wav, &wav_string);
+  if (!read_wav_status.ok()) {
+    LOG(ERROR) << read_wav_status;
+    return -1;
+  }
+  Tensor wav_tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
+  wav_tensor.scalar<string>()() = wav_string;
+
+  // Actually run the audio through the model.
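+  // Session::Run is fed the raw WAV bytes under the input name ('wav_data'
+  // by default) and asked to fetch the single softmax output, which should
+  // come back as outputs[0] with one score per label.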
+  std::vector<Tensor> outputs;
+  Status run_status =
+      session->Run({{input_name, wav_tensor}}, {output_name}, {}, &outputs);
+  if (!run_status.ok()) {
+    LOG(ERROR) << "Running model failed: " << run_status;
+    return -1;
+  }
+
+  Tensor indices;
+  Tensor scores;
+  GetTopLabels(outputs, how_many_labels, &indices, &scores);
+  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
+  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+  for (int pos = 0; pos < how_many_labels; ++pos) {
+    const int label_index = indices_flat(pos);
+    const float score = scores_flat(pos);
+    LOG(INFO) << labels_list[label_index] << " (" << label_index
+              << "): " << score;
+  }
+
+  return 0;
+}
diff --git a/tensorflow/examples/speech_commands/label_wav.py b/tensorflow/examples/speech_commands/label_wav.py
new file mode 100644
index 00000000000..0017aec3a54
--- /dev/null
+++ b/tensorflow/examples/speech_commands/label_wav.py
@@ -0,0 +1,133 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Runs a trained audio graph against a WAVE file and reports the results.
+
+The model, labels and .wav file specified in the arguments will be loaded, and
+then the predictions from running the model against the audio data will be
+printed to the console. This is a useful script for sanity checking trained
+models, and as an example of how to use an audio model from Python.
+
+Here's an example of running it:
+
+python tensorflow/examples/speech_commands/label_wav.py \
+--graph=/tmp/my_frozen_graph.pb \
+--labels=/tmp/speech_commands_train/conv_labels.txt \
+--wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import tensorflow as tf
+
+# pylint: disable=unused-import
+from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
+# pylint: enable=unused-import
+
+FLAGS = None
+
+
+def load_graph(filename):
+  """Unpersists graph from file as default graph."""
+  with tf.gfile.FastGFile(filename, 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+    tf.import_graph_def(graph_def, name='')
+
+
+def load_labels(filename):
+  """Read in labels, one label per line."""
+  return [line.rstrip() for line in tf.gfile.GFile(filename)]
+
+
+def run_graph(wav_data, labels, input_layer_name, output_layer_name,
+              num_top_predictions):
+  """Runs the audio data through the graph and prints predictions."""
+  with tf.Session() as sess:
+    # Feed the audio data as input to the graph.
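+    # Note that what we feed here are the raw bytes of the WAV file (a string
+    # tensor), not decoded samples; the frozen graph built by freeze.py
+    # includes the decode_wav step behind the 'wav_data' input.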
+ # predictions will contain a two-dimensional array, where one + # dimension represents the input image count, and the other has + # predictions per class + softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) + predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data}) + + # Sort to show labels in order of confidence + top_k = predictions.argsort()[-num_top_predictions:][::-1] + for node_id in top_k: + human_string = labels[node_id] + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + + return 0 + + +def label_wav(wav, labels, graph, input_name, output_name, how_many_labels): + """Loads the model and labels, and runs the inference to print predictions.""" + if not wav or not tf.gfile.Exists(wav): + tf.logging.fatal('Audio file does not exist %s', wav) + + if not labels or not tf.gfile.Exists(labels): + tf.logging.fatal('Labels file does not exist %s', labels) + + if not graph or not tf.gfile.Exists(graph): + tf.logging.fatal('Graph file does not exist %s', graph) + + labels_list = load_labels(labels) + + # load graph, which is stored in the default session + load_graph(graph) + + with open(wav, 'rb') as wav_file: + wav_data = wav_file.read() + + run_graph(wav_data, labels_list, input_name, output_name, how_many_labels) + + +def main(_): + """Entry point for script, converts flags to arguments.""" + label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name, + FLAGS.output_name, FLAGS.how_many_labels) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--wav', type=str, default='', help='Audio file to be identified.') + parser.add_argument( + '--graph', type=str, default='', help='Model to use for identification.') + parser.add_argument( + '--labels', type=str, default='', help='Path to file containing labels.') + parser.add_argument( + '--input_name', + type=str, + default='wav_data:0', + help='Name of WAVE data input node in model.') + parser.add_argument( + '--output_name', + type=str, + default='labels_softmax:0', + help='Name of node outputting a prediction in the model.') + parser.add_argument( + '--how_many_labels', + type=int, + default=3, + help='Number of results to show.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/label_wav_test.py b/tensorflow/examples/speech_commands/label_wav_test.py new file mode 100644 index 00000000000..80ca7747062 --- /dev/null +++ b/tensorflow/examples/speech_commands/label_wav_test.py @@ -0,0 +1,64 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for WAVE file labeling tool.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +from tensorflow.examples.speech_commands import label_wav +from tensorflow.python.platform import test + + +class LabelWavTest(test.TestCase): + + def _getWavData(self): + with self.test_session() as sess: + sample_data = tf.zeros([1000, 2]) + wav_encoder = contrib_audio.encode_wav(sample_data, 16000) + wav_data = sess.run(wav_encoder) + return wav_data + + def _saveTestWavFile(self, filename, wav_data): + with open(filename, "wb") as f: + f.write(wav_data) + + def testLabelWav(self): + tmp_dir = self.get_temp_dir() + wav_data = self._getWavData() + wav_filename = os.path.join(tmp_dir, "wav_file.wav") + self._saveTestWavFile(wav_filename, wav_data) + input_name = "test_input" + output_name = "test_output" + graph_filename = os.path.join(tmp_dir, "test_graph.pb") + with tf.Session() as sess: + tf.placeholder(tf.string, name=input_name) + tf.zeros([1, 3], name=output_name) + with open(graph_filename, "wb") as f: + f.write(sess.graph.as_graph_def().SerializeToString()) + labels_filename = os.path.join(tmp_dir, "test_labels.txt") + with open(labels_filename, "w") as f: + f.write("a\nb\nc\n") + label_wav.label_wav(wav_filename, labels_filename, graph_filename, + input_name + ":0", output_name + ":0", 3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py new file mode 100644 index 00000000000..9eafb933fb1 --- /dev/null +++ b/tensorflow/examples/speech_commands/models.py @@ -0,0 +1,378 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model definitions for simple speech recognition. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import tensorflow as tf + + +def prepare_model_settings(label_count, sample_rate, clip_duration_ms, + window_size_ms, window_stride_ms, + dct_coefficient_count): + """Calculates common settings needed for all models. + + Args: + label_count: How many classes are to be recognized. + sample_rate: Number of audio samples per second. + clip_duration_ms: Length of each audio clip to be analyzed. + window_size_ms: Duration of frequency analysis window. + window_stride_ms: How far to move in time between frequency windows. + dct_coefficient_count: Number of frequency bins to use for analysis. + + Returns: + Dictionary containing common settings. 
+ """ + desired_samples = int(sample_rate * clip_duration_ms / 1000) + window_size_samples = int(sample_rate * window_size_ms / 1000) + window_stride_samples = int(sample_rate * window_stride_ms / 1000) + length_minus_window = (desired_samples - window_size_samples) + if length_minus_window < 0: + spectrogram_length = 0 + else: + spectrogram_length = 1 + int(length_minus_window / window_stride_samples) + fingerprint_size = dct_coefficient_count * spectrogram_length + return { + 'desired_samples': desired_samples, + 'window_size_samples': window_size_samples, + 'window_stride_samples': window_stride_samples, + 'spectrogram_length': spectrogram_length, + 'dct_coefficient_count': dct_coefficient_count, + 'fingerprint_size': fingerprint_size, + 'label_count': label_count, + 'sample_rate': sample_rate, + } + + +def create_model(fingerprint_input, model_settings, model_architecture, + is_training): + """Builds a model of the requested architecture compatible with the settings. + + There are many possible ways of deriving predictions from a spectrogram + input, so this function provides an abstract interface for creating different + kinds of models in a black-box way. You need to pass in a TensorFlow node as + the 'fingerprint' input, and this should output a batch of 1D features that + describe the audio. Typically this will be derived from a spectrogram that's + been run through an MFCC, but in theory it can be any feature vector of the + size specified in model_settings['fingerprint_size']. + + The function will build the graph it needs in the current TensorFlow graph, + and return the tensorflow output that will contain the 'logits' input to the + softmax prediction process. If training flag is on, it will also return a + placeholder node that can be used to control the dropout amount. + + See the implementations below for the possible model architectures that can be + requested. + + Args: + fingerprint_input: TensorFlow node that will output audio feature vectors. + model_settings: Dictionary of information about the model. + model_architecture: String specifying which kind of model to create. + is_training: Whether the model is going to be used for training. + + Returns: + TensorFlow node outputting logits results, and optionally a dropout + placeholder. + + Raises: + Exception: If the architecture type isn't recognized. + """ + if model_architecture == 'single_fc': + return create_single_fc_model(fingerprint_input, model_settings, + is_training) + elif model_architecture == 'conv': + return create_conv_model(fingerprint_input, model_settings, is_training) + elif model_architecture == 'low_latency_conv': + return create_low_latency_conv_model(fingerprint_input, model_settings, + is_training) + else: + raise Exception('model_architecture argument "' + model_architecture + + '" not recognized, should be one of "single_fc", "conv",' + + ' or "low_latency_conv"') + + +def load_variables_from_checkpoint(sess, start_checkpoint): + """Utility function to centralize checkpoint restoration. + + Args: + sess: TensorFlow session. + start_checkpoint: Path to saved checkpoint on disk. + """ + saver = tf.train.Saver(tf.global_variables()) + saver.restore(sess, start_checkpoint) + + +def create_single_fc_model(fingerprint_input, model_settings, is_training): + """Builds a model with a single hidden fully-connected layer. + + This is a very simple model with just one matmul and bias layer. 
As you'd + expect, it doesn't produce very accurate results, but it is very fast and + simple, so it's useful for sanity testing. + + Here's the layout of the graph: + + (fingerprint_input) + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + + Args: + fingerprint_input: TensorFlow node that will output audio feature vectors. + model_settings: Dictionary of information about the model. + is_training: Whether the model is going to be used for training. + + Returns: + TensorFlow node outputting logits results, and optionally a dropout + placeholder. + """ + if is_training: + dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') + fingerprint_size = model_settings['fingerprint_size'] + label_count = model_settings['label_count'] + weights = tf.Variable( + tf.truncated_normal([fingerprint_size, label_count], stddev=0.001)) + bias = tf.Variable(tf.zeros([label_count])) + logits = tf.matmul(fingerprint_input, weights) + bias + if is_training: + return logits, dropout_prob + else: + return logits + + +def create_conv_model(fingerprint_input, model_settings, is_training): + """Builds a standard convolutional model. + + This is roughly the network labeled as 'cnn-trad-fpool3' in the + 'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper: + http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf + + Here's the layout of the graph: + + (fingerprint_input) + v + [Conv2D]<-(weights) + v + [BiasAdd]<-(bias) + v + [Relu] + v + [MaxPool] + v + [Conv2D]<-(weights) + v + [BiasAdd]<-(bias) + v + [Relu] + v + [MaxPool] + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + + This produces fairly good quality results, but can involve a large number of + weight parameters and computations. For a cheaper alternative from the same + paper with slightly less accuracy, see 'low_latency_conv' below. + + During training, dropout nodes are introduced after each relu, controlled by a + placeholder. + + Args: + fingerprint_input: TensorFlow node that will output audio feature vectors. + model_settings: Dictionary of information about the model. + is_training: Whether the model is going to be used for training. + + Returns: + TensorFlow node outputting logits results, and optionally a dropout + placeholder. 
+ """ + if is_training: + dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') + input_frequency_size = model_settings['dct_coefficient_count'] + input_time_size = model_settings['spectrogram_length'] + fingerprint_4d = tf.reshape(fingerprint_input, + [-1, input_time_size, input_frequency_size, 1]) + first_filter_width = 8 + first_filter_height = 20 + first_filter_count = 64 + first_weights = tf.Variable( + tf.truncated_normal( + [first_filter_height, first_filter_width, 1, first_filter_count], + stddev=0.01)) + first_bias = tf.Variable(tf.zeros([first_filter_count])) + first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1], + 'SAME') + first_bias + first_relu = tf.nn.relu(first_conv) + if is_training: + first_dropout = tf.nn.dropout(first_relu, dropout_prob) + else: + first_dropout = first_relu + max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') + second_filter_width = 4 + second_filter_height = 10 + second_filter_count = 64 + second_weights = tf.Variable( + tf.truncated_normal( + [ + second_filter_height, second_filter_width, first_filter_count, + second_filter_count + ], + stddev=0.01)) + second_bias = tf.Variable(tf.zeros([second_filter_count])) + second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1], + 'SAME') + second_bias + second_relu = tf.nn.relu(second_conv) + if is_training: + second_dropout = tf.nn.dropout(second_relu, dropout_prob) + else: + second_dropout = second_relu + second_conv_shape = second_dropout.get_shape() + second_conv_output_width = second_conv_shape[2] + second_conv_output_height = second_conv_shape[1] + second_conv_element_count = int( + second_conv_output_width * second_conv_output_height * + second_filter_count) + flattened_second_conv = tf.reshape(second_dropout, + [-1, second_conv_element_count]) + label_count = model_settings['label_count'] + final_fc_weights = tf.Variable( + tf.truncated_normal( + [second_conv_element_count, label_count], stddev=0.01)) + final_fc_bias = tf.Variable(tf.zeros([label_count])) + final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias + if is_training: + return final_fc, dropout_prob + else: + return final_fc + + +def create_low_latency_conv_model(fingerprint_input, model_settings, + is_training): + """Builds a convolutional model with low compute requirements. + + This is roughly the network labeled as 'cnn-one-fstride4' in the + 'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper: + http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf + + Here's the layout of the graph: + + (fingerprint_input) + v + [Conv2D]<-(weights) + v + [BiasAdd]<-(bias) + v + [Relu] + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + + This produces slightly lower quality results than the 'conv' model, but needs + fewer weight parameters and computations. + + During training, dropout nodes are introduced after the relu, controlled by a + placeholder. + + Args: + fingerprint_input: TensorFlow node that will output audio feature vectors. + model_settings: Dictionary of information about the model. + is_training: Whether the model is going to be used for training. + + Returns: + TensorFlow node outputting logits results, and optionally a dropout + placeholder. 
+ """ + if is_training: + dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') + input_frequency_size = model_settings['dct_coefficient_count'] + input_time_size = model_settings['spectrogram_length'] + fingerprint_4d = tf.reshape(fingerprint_input, + [-1, input_time_size, input_frequency_size, 1]) + first_filter_width = 8 + first_filter_height = input_time_size + first_filter_count = 186 + first_filter_stride_x = 1 + first_filter_stride_y = 4 + first_weights = tf.Variable( + tf.truncated_normal( + [first_filter_height, first_filter_width, 1, first_filter_count], + stddev=0.01)) + first_bias = tf.Variable(tf.zeros([first_filter_count])) + first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [ + 1, first_filter_stride_y, first_filter_stride_x, 1 + ], 'VALID') + first_bias + first_relu = tf.nn.relu(first_conv) + if is_training: + first_dropout = tf.nn.dropout(first_relu, dropout_prob) + else: + first_dropout = first_relu + first_conv_output_width = math.floor( + (input_frequency_size - first_filter_width + first_filter_stride_x) / + first_filter_stride_x) + first_conv_output_height = math.floor( + (input_time_size - first_filter_height + first_filter_stride_y) / + first_filter_stride_y) + first_conv_element_count = int( + first_conv_output_width * first_conv_output_height * first_filter_count) + flattened_first_conv = tf.reshape(first_dropout, + [-1, first_conv_element_count]) + first_fc_output_channels = 128 + first_fc_weights = tf.Variable( + tf.truncated_normal( + [first_conv_element_count, first_fc_output_channels], stddev=0.01)) + first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels])) + first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias + if is_training: + second_fc_input = tf.nn.dropout(first_fc, dropout_prob) + else: + second_fc_input = first_fc + second_fc_output_channels = 128 + second_fc_weights = tf.Variable( + tf.truncated_normal( + [first_fc_output_channels, second_fc_output_channels], stddev=0.01)) + second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels])) + second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias + if is_training: + final_fc_input = tf.nn.dropout(second_fc, dropout_prob) + else: + final_fc_input = second_fc + label_count = model_settings['label_count'] + final_fc_weights = tf.Variable( + tf.truncated_normal( + [second_fc_output_channels, label_count], stddev=0.01)) + final_fc_bias = tf.Variable(tf.zeros([label_count])) + final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias + if is_training: + return final_fc, dropout_prob + else: + return final_fc diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py new file mode 100644 index 00000000000..80c795367fa --- /dev/null +++ b/tensorflow/examples/speech_commands/models_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for speech commands models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.examples.speech_commands import models +from tensorflow.python.platform import test + + +class ModelsTest(test.TestCase): + + def testPrepareModelSettings(self): + self.assertIsNotNone( + models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)) + + def testCreateModelConvTraining(self): + model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + with self.test_session() as sess: + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + logits, dropout_prob = models.create_model(fingerprint_input, + model_settings, "conv", True) + self.assertIsNotNone(logits) + self.assertIsNotNone(dropout_prob) + self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) + self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) + + def testCreateModelConvInference(self): + model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + with self.test_session() as sess: + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + logits = models.create_model(fingerprint_input, model_settings, "conv", + False) + self.assertIsNotNone(logits) + self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) + + def testCreateModelLowLatencyConvTraining(self): + model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + with self.test_session() as sess: + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + logits, dropout_prob = models.create_model( + fingerprint_input, model_settings, "low_latency_conv", True) + self.assertIsNotNone(logits) + self.assertIsNotNone(dropout_prob) + self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) + self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) + + def testCreateModelFullyConnectedTraining(self): + model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + with self.test_session() as sess: + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + logits, dropout_prob = models.create_model( + fingerprint_input, model_settings, "single_fc", True) + self.assertIsNotNone(logits) + self.assertIsNotNone(dropout_prob) + self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) + self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) + + def testCreateModelBadArchitecture(self): + model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + with self.test_session(): + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + with self.assertRaises(Exception) as e: + models.create_model(fingerprint_input, model_settings, + "bad_architecture", True) + self.assertTrue("not recognized" in str(e.exception)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/examples/speech_commands/recognize_commands.cc b/tensorflow/examples/speech_commands/recognize_commands.cc new file mode 100644 index 00000000000..1d2c19ff15d --- /dev/null +++ b/tensorflow/examples/speech_commands/recognize_commands.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/speech_commands/recognize_commands.h" + +namespace tensorflow { + +RecognizeCommands::RecognizeCommands(const std::vector& labels, + int32 average_window_duration_ms, + float detection_threshold, + int32 suppression_ms, int32 minimum_count) + : labels_(labels), + average_window_duration_ms_(average_window_duration_ms), + detection_threshold_(detection_threshold), + suppression_ms_(suppression_ms), + minimum_count_(minimum_count) { + labels_count_ = labels.size(); + previous_top_label_ = "_silence_"; + previous_top_label_time_ = std::numeric_limits::min(); +} + +Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results, + const int64 current_time_ms, + string* found_command, + float* score, + bool* is_new_command) { + if (latest_results.NumElements() != labels_count_) { + return errors::InvalidArgument( + "The results for recognition should contain ", labels_count_, + " elements, but there are ", latest_results.NumElements()); + } + + if ((!previous_results_.empty()) && + (current_time_ms < previous_results_.front().first)) { + return errors::InvalidArgument( + "Results must be fed in increasing time order, but received a " + "timestamp of ", + current_time_ms, " that was earlier than the previous one of ", + previous_results_.front().first); + } + + // Add the latest results to the head of the queue. + previous_results_.push_back({current_time_ms, latest_results}); + + // Prune any earlier results that are too old for the averaging window. + const int64 time_limit = current_time_ms - average_window_duration_ms_; + while (previous_results_.front().first < time_limit) { + previous_results_.pop_front(); + } + + // If there are too few results, assume the result will be unreliable and + // bail. + const int64 how_many_results = previous_results_.size(); + const int64 earliest_time = previous_results_.front().first; + const int64 samples_duration = current_time_ms - earliest_time; + if ((how_many_results < minimum_count_) || + (samples_duration < (average_window_duration_ms_ / 4))) { + *found_command = previous_top_label_; + *score = 0.0f; + *is_new_command = false; + return Status::OK(); + } + + // Calculate the average score across all the results in the window. + std::vector average_scores(labels_count_); + for (const auto& previous_result : previous_results_) { + const Tensor& scores_tensor = previous_result.second; + auto scores_flat = scores_tensor.flat(); + for (int i = 0; i < scores_flat.size(); ++i) { + average_scores[i] += scores_flat(i) / how_many_results; + } + } + + // Sort the averaged results in descending score order. 
+ std::vector> sorted_average_scores; + sorted_average_scores.reserve(labels_count_); + for (int i = 0; i < labels_count_; ++i) { + sorted_average_scores.push_back( + std::pair({i, average_scores[i]})); + } + std::sort(sorted_average_scores.begin(), sorted_average_scores.end(), + [](const std::pair& left, + const std::pair& right) { + return left.second > right.second; + }); + + // See if the latest top score is enough to trigger a detection. + const int current_top_index = sorted_average_scores[0].first; + const string current_top_label = labels_[current_top_index]; + const float current_top_score = sorted_average_scores[0].second; + // If we've recently had another label trigger, assume one that occurs too + // soon afterwards is a bad result. + int64 time_since_last_top; + if ((previous_top_label_ == "_silence_") || + (previous_top_label_time_ == std::numeric_limits::min())) { + time_since_last_top = std::numeric_limits::max(); + } else { + time_since_last_top = current_time_ms - previous_top_label_time_; + } + if ((current_top_score > detection_threshold_) && + (current_top_label != previous_top_label_) && + (time_since_last_top > suppression_ms_)) { + previous_top_label_ = current_top_label; + previous_top_label_time_ = current_time_ms; + *is_new_command = true; + } else { + *is_new_command = false; + } + *found_command = current_top_label; + *score = current_top_score; + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/examples/speech_commands/recognize_commands.h b/tensorflow/examples/speech_commands/recognize_commands.h new file mode 100644 index 00000000000..7f8041f9ed3 --- /dev/null +++ b/tensorflow/examples/speech_commands/recognize_commands.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// This class is designed to apply a very primitive decoding model on top of the +// instantaneous results from running an audio recognition model on a single +// window of samples. It applies smoothing over time so that noisy individual +// label scores are averaged, increasing the confidence that apparent matches +// are real. +// To use it, you should create a class object with the configuration you +// want, and then feed results from running a TensorFlow model into the +// processing method. The timestamp for each subsequent call should be +// increasing from the previous, since the class is designed to process a stream +// of data over time. +class RecognizeCommands { + public: + // labels should be a list of the strings associated with each one-hot score. + // The window duration controls the smoothing. 
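To make the smoothing logic above easier to follow outside of C++, here is a rough Python sketch of the same idea: keep a window of recent score vectors, average them, and only report a new command when the averaged top score clears the threshold and enough time has passed since the last detection. It mirrors ProcessLatestResults but drops the input validation, and it is an illustration rather than code from this patch.

# Illustrative Python version of the RecognizeCommands smoothing logic.
import collections

import numpy as np


class SimpleRecognizer(object):
  """Averages per-window scores over time and suppresses repeat detections."""

  def __init__(self, labels, average_window_ms=1000, detection_threshold=0.2,
               suppression_ms=500, minimum_count=3):
    self.labels = labels
    self.average_window_ms = average_window_ms
    self.detection_threshold = detection_threshold
    self.suppression_ms = suppression_ms
    self.minimum_count = minimum_count
    self.results = collections.deque()   # (time_ms, scores) pairs
    self.previous_top_label = '_silence_'
    self.previous_top_time = -np.inf

  def process(self, scores, current_time_ms):
    self.results.append((current_time_ms, np.asarray(scores, dtype=np.float32)))
    # Drop results that have fallen out of the averaging window.
    time_limit = current_time_ms - self.average_window_ms
    while self.results[0][0] < time_limit:
      self.results.popleft()
    # Too little data to be reliable yet: report the previous label quietly.
    if (len(self.results) < self.minimum_count or
        current_time_ms - self.results[0][0] < self.average_window_ms / 4):
      return self.previous_top_label, 0.0, False
    average = np.mean([s for _, s in self.results], axis=0)
    top_index = int(np.argmax(average))
    top_label = self.labels[top_index]
    top_score = float(average[top_index])
    is_new_command = (top_score > self.detection_threshold and
                      top_label != self.previous_top_label and
                      current_time_ms - self.previous_top_time >
                      self.suppression_ms)
    if is_new_command:
      self.previous_top_label = top_label
      self.previous_top_time = current_time_ms
    return top_label, top_score, is_new_command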
Longer durations will give a + // higher confidence that the results are correct, but may miss some commands. + // The detection threshold has a similar effect, with high values increasing + // the precision at the cost of recall. The minimum count controls how many + // results need to be in the averaging window before it's seen as a reliable + // average. This prevents erroneous results when the averaging window is + // initially being populated for example. The suppression argument disables + // further recognitions for a set time after one has been triggered, which can + // help reduce spurious recognitions. + explicit RecognizeCommands(const std::vector& labels, + int32 average_window_duration_ms = 1000, + float detection_threshold = 0.2, + int32 suppression_ms = 500, + int32 minimum_count = 3); + + // Call this with the results of running a model on sample data. + Status ProcessLatestResults(const Tensor& latest_results, + const int64 current_time_ms, + string* found_command, float* score, + bool* is_new_command); + + private: + // Configuration + std::vector labels_; + int32 average_window_duration_ms_; + float detection_threshold_; + int32 suppression_ms_; + int32 minimum_count_; + + // Working variables + std::deque> previous_results_; + string previous_top_label_; + int64 labels_count_; + int64 previous_top_label_time_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_ diff --git a/tensorflow/examples/speech_commands/recognize_commands_test.cc b/tensorflow/examples/speech_commands/recognize_commands_test.cc new file mode 100644 index 00000000000..4a5ee0fe4c7 --- /dev/null +++ b/tensorflow/examples/speech_commands/recognize_commands_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/examples/speech_commands/recognize_commands.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(RecognizeCommandsTest, Basic) { + RecognizeCommands recognize_commands({"_silence_", "a", "b"}); + + Tensor results(DT_FLOAT, {3}); + test::FillValues(&results, {1.0f, 0.0f, 0.0f}); + + string found_command; + float score; + bool is_new_command; + TF_EXPECT_OK(recognize_commands.ProcessLatestResults( + results, 0, &found_command, &score, &is_new_command)); +} + +TEST(RecognizeCommandsTest, FindCommands) { + RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f); + + Tensor results(DT_FLOAT, {3}); + + test::FillValues(&results, {0.0f, 1.0f, 0.0f}); + bool has_found_new_command = false; + string new_command; + for (int i = 0; i < 10; ++i) { + string found_command; + float score; + bool is_new_command; + int64 current_time_ms = 0 + (i * 100); + TF_EXPECT_OK(recognize_commands.ProcessLatestResults( + results, current_time_ms, &found_command, &score, &is_new_command)); + if (is_new_command) { + EXPECT_FALSE(has_found_new_command); + has_found_new_command = true; + new_command = found_command; + } + } + EXPECT_TRUE(has_found_new_command); + EXPECT_EQ("a", new_command); + + test::FillValues(&results, {0.0f, 0.0f, 1.0f}); + has_found_new_command = false; + new_command = ""; + for (int i = 0; i < 10; ++i) { + string found_command; + float score; + bool is_new_command; + int64 current_time_ms = 1000 + (i * 100); + TF_EXPECT_OK(recognize_commands.ProcessLatestResults( + results, current_time_ms, &found_command, &score, &is_new_command)); + if (is_new_command) { + EXPECT_FALSE(has_found_new_command); + has_found_new_command = true; + new_command = found_command; + } + } + EXPECT_TRUE(has_found_new_command); + EXPECT_EQ("b", new_command); +} + +TEST(RecognizeCommandsTest, BadInputLength) { + RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f); + + Tensor bad_results(DT_FLOAT, {2}); + test::FillValues(&bad_results, {1.0f, 0.0f}); + + string found_command; + float score; + bool is_new_command; + EXPECT_FALSE(recognize_commands + .ProcessLatestResults(bad_results, 0, &found_command, &score, + &is_new_command) + .ok()); +} + +TEST(RecognizeCommandsTest, BadInputTimes) { + RecognizeCommands recognize_commands({"_silence_", "a", "b"}, 1000, 0.2f); + + Tensor results(DT_FLOAT, {3}); + test::FillValues(&results, {1.0f, 0.0f, 0.0f}); + + string found_command; + float score; + bool is_new_command; + TF_EXPECT_OK(recognize_commands.ProcessLatestResults( + results, 100, &found_command, &score, &is_new_command)); + EXPECT_FALSE(recognize_commands + .ProcessLatestResults(results, 0, &found_command, &score, + &is_new_command) + .ok()); +} + +} // namespace tensorflow diff --git a/tensorflow/examples/speech_commands/test_streaming_accuracy.cc b/tensorflow/examples/speech_commands/test_streaming_accuracy.cc new file mode 100644 index 00000000000..5df944f2096 --- /dev/null +++ b/tensorflow/examples/speech_commands/test_streaming_accuracy.cc @@ -0,0 +1,310 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/* + +Tool to create accuracy statistics from running an audio recognition model on a +continuous stream of samples. + +This is designed to be an environment for running experiments on new models and +settings to understand the effects they will have in a real application. You +need to supply it with a long audio file containing sounds you want to recognize +and a text file listing the labels of each sound along with the time they occur. +With this information, and a frozen model, the tool will process the audio +stream, apply the model, and keep track of how many mistakes and successes the +model achieved. + +The matched percentage is the number of sounds that were correctly classified, +as a percentage of the total number of sounds listed in the ground truth file. +A correct classification is when the right label is chosen within a short time +of the expected ground truth, where the time tolerance is controlled by the +'time_tolerance_ms' command line flag. + +The wrong percentage is how many sounds triggered a detection (the classifier +figured out it wasn't silence or background noise), but the detected class was +wrong. This is also a percentage of the total number of ground truth sounds. + +The false positive percentage is how many sounds were detected when there was +only silence or background noise. This is also expressed as a percentage of the +total number of ground truth sounds, though since it can be large it may go +above 100%. + +The easiest way to get an audio file and labels to test with is by using the +'generate_streaming_test_wav' script. This will synthesize a test file with +randomly placed sounds and background noise, and output a text file with the +ground truth. + +If you want to test natural data, you need to use a .wav with the same sample +rate as your model (often 16,000 samples per second), and note down where the +sounds occur in time. Save this information out as a comma-separated text file, +where the first column is the label and the second is the time in seconds from +the start of the file that it occurs. 
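The three percentages described above are computed by accuracy_utils.cc, which is not shown in this excerpt. As a rough illustration of the definitions only, a greedy Python version might look like the sketch below; the one-detection-per-ground-truth matching strategy is an assumption made for the example, not necessarily how the real tool pairs detections with ground truth.

# Illustration of the matched / wrong / false-positive percentages described
# above. The greedy matching here is an assumption made for the example; the
# real logic lives in accuracy_utils.cc.
def streaming_accuracy(ground_truth, found_words, time_tolerance_ms=750):
  """Both arguments are lists of (label, time_ms) pairs."""
  unmatched_truth = list(ground_truth)
  correct, wrong, false_positives = 0, 0, 0
  for label, time_ms in found_words:
    # Ground-truth entries close enough in time to this detection.
    candidates = [g for g in unmatched_truth
                  if abs(g[1] - time_ms) <= time_tolerance_ms]
    if not candidates:
      false_positives += 1   # detection with no nearby ground truth
    elif any(g[0] == label for g in candidates):
      correct += 1           # right label within the time tolerance
      unmatched_truth.remove(next(g for g in candidates if g[0] == label))
    else:
      wrong += 1             # something detected, but the wrong label
      unmatched_truth.remove(candidates[0])
  total = max(len(ground_truth), 1)  # avoid dividing by zero on empty input
  return {'matched_pct': 100.0 * correct / total,
          'wrong_pct': 100.0 * wrong / total,
          'false_positive_pct': 100.0 * false_positives / total}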
+ +Here's an example of how to run the tool: + +bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \ +--wav=/tmp/streaming_test_bg.wav \ +--graph=/tmp/conv_frozen.pb \ +--labels=/tmp/speech_commands_train/conv_labels.txt \ +--ground_truth=/tmp/streaming_test_labels.txt --verbose \ +--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \ +--suppression_ms=500 --time_tolerance_ms=1500 + + */ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/wav/wav_io.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/examples/speech_commands/accuracy_utils.h" +#include "tensorflow/examples/speech_commands/recognize_commands.h" + +// These are all common classes it's handy to reference with no namespace. +using tensorflow::Flag; +using tensorflow::Status; +using tensorflow::Tensor; +using tensorflow::int32; +using tensorflow::string; + +namespace { + +// Reads a model graph definition from disk, and creates a session object you +// can use to run it. +Status LoadGraph(const string& graph_file_name, + std::unique_ptr* session) { + tensorflow::GraphDef graph_def; + Status load_graph_status = + ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); + if (!load_graph_status.ok()) { + return tensorflow::errors::NotFound("Failed to load compute graph at '", + graph_file_name, "'"); + } + session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); + Status session_create_status = (*session)->Create(graph_def); + if (!session_create_status.ok()) { + return session_create_status; + } + return Status::OK(); +} + +// Takes a file name, and loads a list of labels from it, one per line, and +// returns a vector of the strings. 
+Status ReadLabelsFile(const string& file_name, std::vector* result) { + std::ifstream file(file_name); + if (!file) { + return tensorflow::errors::NotFound("Labels file '", file_name, + "' not found."); + } + result->clear(); + string line; + while (std::getline(file, line)) { + result->push_back(line); + } + return Status::OK(); +} + +} // namespace + +int main(int argc, char* argv[]) { + string wav = ""; + string graph = ""; + string labels = ""; + string ground_truth = ""; + string input_data_name = "decoded_sample_data:0"; + string input_rate_name = "decoded_sample_data:1"; + string output_name = "labels_softmax"; + int32 clip_duration_ms = 1000; + int32 sample_stride_ms = 30; + int32 average_window_ms = 500; + int32 time_tolerance_ms = 750; + int32 suppression_ms = 1500; + float detection_threshold = 0.7f; + bool verbose = false; + std::vector flag_list = { + Flag("wav", &wav, "audio file to be identified"), + Flag("graph", &graph, "model to be executed"), + Flag("labels", &labels, "path to file containing labels"), + Flag("ground_truth", &ground_truth, + "path to file containing correct times and labels of words in the " + "audio as , lines"), + Flag("input_data_name", &input_data_name, + "name of input data node in model"), + Flag("input_rate_name", &input_rate_name, + "name of input sample rate node in model"), + Flag("output_name", &output_name, "name of output node in model"), + Flag("clip_duration_ms", &clip_duration_ms, + "length of recognition window"), + Flag("average_window_ms", &average_window_ms, + "length of window to smooth results over"), + Flag("time_tolerance_ms", &time_tolerance_ms, + "maximum gap allowed between a recognition and ground truth"), + Flag("suppression_ms", &suppression_ms, + "how long to ignore others for after a recognition"), + Flag("sample_stride_ms", &sample_stride_ms, + "how often to run recognition"), + Flag("detection_threshold", &detection_threshold, + "what score is required to trigger detection of a word"), + Flag("verbose", &verbose, "whether to log extra debugging information"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + + // We need to call this to set up global state for TensorFlow. + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return -1; + } + + // First we load and initialize the model. 
+ std::unique_ptr session; + Status load_graph_status = LoadGraph(graph, &session); + if (!load_graph_status.ok()) { + LOG(ERROR) << load_graph_status; + return -1; + } + + std::vector labels_list; + Status read_labels_status = ReadLabelsFile(labels, &labels_list); + if (!read_labels_status.ok()) { + LOG(ERROR) << read_labels_status; + return -1; + } + + std::vector> ground_truth_list; + Status read_ground_truth_status = + tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list); + if (!read_ground_truth_status.ok()) { + LOG(ERROR) << read_ground_truth_status; + return -1; + } + + string wav_string; + Status read_wav_status = tensorflow::ReadFileToString( + tensorflow::Env::Default(), wav, &wav_string); + if (!read_wav_status.ok()) { + LOG(ERROR) << read_wav_status; + return -1; + } + std::vector audio_data; + uint32 sample_count; + uint16 channel_count; + uint32 sample_rate; + Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector( + wav_string, &audio_data, &sample_count, &channel_count, &sample_rate); + if (!decode_wav_status.ok()) { + LOG(ERROR) << decode_wav_status; + return -1; + } + if (channel_count != 1) { + LOG(ERROR) << "Only mono .wav files can be used, but input has " + << channel_count << " channels."; + return -1; + } + + const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000; + const int64 sample_stride_samples = (sample_stride_ms * sample_rate) / 1000; + Tensor audio_data_tensor(tensorflow::DT_FLOAT, + tensorflow::TensorShape({clip_duration_samples, 1})); + + Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({})); + sample_rate_tensor.scalar()() = sample_rate; + + tensorflow::RecognizeCommands recognize_commands( + labels_list, average_window_ms, detection_threshold, suppression_ms); + + std::vector> all_found_words; + tensorflow::StreamingAccuracyStats previous_stats; + + const int64 audio_data_end = (sample_count - clip_duration_ms); + for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end; + audio_data_offset += sample_stride_samples) { + const float* input_start = &(audio_data[audio_data_offset]); + const float* input_end = input_start + clip_duration_samples; + std::copy(input_start, input_end, audio_data_tensor.flat().data()); + + // Actually run the audio through the model. 
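The loop above slides a fixed-length window over the decoded audio and re-runs the model at every stride. The arithmetic is easier to see in a few lines of Python: with the default 1000 ms clip and 30 ms stride at 16 kHz, a 16,000-sample window advances 480 samples at a time. This is only a sketch of the indexing, not a replacement for the C++ tool.

# Sketch of the streaming window arithmetic used by the loop above.
sample_rate = 16000
clip_duration_ms = 1000   # matches the --clip_duration_ms default
sample_stride_ms = 30     # matches the --sample_stride_ms default

clip_duration_samples = (clip_duration_ms * sample_rate) // 1000   # 16000
sample_stride_samples = (sample_stride_ms * sample_rate) // 1000   # 480


def streaming_windows(audio_data):
  """Yields (time_ms, window) pairs covering a long mono recording."""
  audio_data_end = len(audio_data) - clip_duration_samples
  for offset in range(0, audio_data_end, sample_stride_samples):
    current_time_ms = (offset * 1000) // sample_rate
    yield current_time_ms, audio_data[offset:offset + clip_duration_samples]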
+ std::vector outputs; + Status run_status = session->Run({{input_data_name, audio_data_tensor}, + {input_rate_name, sample_rate_tensor}}, + {output_name}, {}, &outputs); + if (!run_status.ok()) { + LOG(ERROR) << "Running model failed: " << run_status; + return -1; + } + + const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate; + string found_command; + float score; + bool is_new_command; + Status recognize_status = recognize_commands.ProcessLatestResults( + outputs[0], current_time_ms, &found_command, &score, &is_new_command); + if (!recognize_status.ok()) { + LOG(ERROR) << "Recognition processing failed: " << recognize_status; + return -1; + } + + if (is_new_command && (found_command != "_silence_")) { + all_found_words.push_back({found_command, current_time_ms}); + if (verbose) { + tensorflow::StreamingAccuracyStats stats; + tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, + current_time_ms, time_tolerance_ms, + &stats); + int32 false_positive_delta = stats.how_many_false_positives - + previous_stats.how_many_false_positives; + int32 correct_delta = stats.how_many_correct_words - + previous_stats.how_many_correct_words; + int32 wrong_delta = + stats.how_many_wrong_words - previous_stats.how_many_wrong_words; + string recognition_state; + if (false_positive_delta == 1) { + recognition_state = " (False Positive)"; + } else if (correct_delta == 1) { + recognition_state = " (Correct)"; + } else if (wrong_delta == 1) { + recognition_state = " (Wrong)"; + } else { + LOG(ERROR) << "Unexpected state in statistics"; + } + LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score + << recognition_state; + previous_stats = stats; + tensorflow::PrintAccuracyStats(stats); + } + } + } + + tensorflow::StreamingAccuracyStats stats; + tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1, + time_tolerance_ms, &stats); + tensorflow::PrintAccuracyStats(stats); + + return 0; +} diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py new file mode 100644 index 00000000000..925607a1fb8 --- /dev/null +++ b/tensorflow/examples/speech_commands/train.py @@ -0,0 +1,427 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Simple speech recognition to spot a limited number of keywords. + +This is a self-contained example script that will train a very basic audio +recognition model in TensorFlow. It can download the necessary training data, +and runs with reasonable defaults to train within a few hours even only using a +CPU. For more information see http://tensorflow.org/tutorials/audio_recognition. + +It is intended as an introduction to using neural networks for audio +recognition, and is not a full speech recognition system. For more advanced +speech systems, I recommend looking into Kaldi. 
This network uses a keyword +detection style to spot discrete words from a small vocabulary, consisting of +"yes", "no", "up", "down", "left", "right", "on", "off", "stop", and "go". + +To run the training process, use: + +bazel run tensorflow/examples/speech_commands:train + +This will write out checkpoints to /tmp/speech_commands_train/, and will +download over 1GB of open source training data, so you'll need enough free space +and a good internet connection. The default data is a collection of thousands of +one-second .wav files, each containing one spoken word. This data set is +collected from https://aiyprojects.withgoogle.com/open_speech_recording, please +consider contributing to help improve this and other models! + +As training progresses, it will print out its accuracy metrics, which should +rise above 90% by the end. Once it's complete, you can run the freeze script to +get a binary GraphDef that you can easily deploy on mobile applications. + +If you want to train on your own data, you'll need to create .wavs with your +recordings, all at a consistent length, and then arrange them into subfolders +organized by label. For example, here's a possible file structure: + +my_wavs > + up > + audio_0.wav + audio_1.wav + down > + audio_2.wav + audio_3.wav + other> + audio_4.wav + audio_5.wav + +You'll also need to tell the script what labels to look for, using the +`--wanted_words` argument. In this case, 'up,down' might be what you want, and +the audio in the 'other' folder would be used to train an 'unknown' category. + +To pull this all together, you'd run: + +bazel run tensorflow/examples/speech_commands:train -- \ +--data_dir=my_wavs --wanted_words=up,down + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os.path +import sys + +import numpy as np +import tensorflow as tf + +import input_data +import models +from tensorflow.python.platform import gfile + +FLAGS = None + + +def main(_): + # We want to see all the logging messages for this tutorial. + tf.logging.set_verbosity(tf.logging.INFO) + + # Start a new TensorFlow session. + sess = tf.InteractiveSession() + + # Begin by making sure we have the training data we need. If you already have + # training data of your own, use `--data_url= ` on the command line to avoid + # downloading. + model_settings = models.prepare_model_settings( + len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))), + FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms, + FLAGS.window_stride_ms, FLAGS.dct_coefficient_count) + audio_processor = input_data.AudioProcessor( + FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage, + FLAGS.unknown_percentage, + FLAGS.wanted_words.split(','), FLAGS.validation_percentage, + FLAGS.testing_percentage, model_settings) + fingerprint_size = model_settings['fingerprint_size'] + label_count = model_settings['label_count'] + time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000) + # Figure out the learning rates for each training phase. Since it's often + # effective to have high learning rates at the start of training, followed by + # lower levels towards the end, the number of steps and learning rates can be + # specified as comma-separated lists to define the rate at each stage. For + # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001 + # will run 13,000 training loops in total, with a rate of 0.001 for the first + # 10,000, and 0.0001 for the final 3,000. 
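A compact illustration of the schedule just described: each entry of --how_many_training_steps pairs with the entry of --learning_rate at the same position, and the rate for a given step is the first one whose cumulative step count has not yet been passed. The helper below is for illustration only; the training loop that follows implements the same lookup inline.

# Illustration of the piecewise learning-rate schedule described above.
def learning_rate_for_step(step, training_steps_list, learning_rates_list):
  steps_sum = 0
  for steps, rate in zip(training_steps_list, learning_rates_list):
    steps_sum += steps
    if step <= steps_sum:
      return rate
  return learning_rates_list[-1]  # past the end; train.py never gets here

# --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
assert learning_rate_for_step(10000, [10000, 3000], [0.001, 0.0001]) == 0.001
assert learning_rate_for_step(10001, [10000, 3000], [0.001, 0.0001]) == 0.0001
assert learning_rate_for_step(13000, [10000, 3000], [0.001, 0.0001]) == 0.0001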
+ training_steps_list = map(int, FLAGS.how_many_training_steps.split(',')) + learning_rates_list = map(float, FLAGS.learning_rate.split(',')) + if len(training_steps_list) != len(learning_rates_list): + raise Exception( + '--how_many_training_steps and --learning_rate must be equal length ' + 'lists, but are %d and %d long instead' % (len(training_steps_list), + len(learning_rates_list))) + + fingerprint_input = tf.placeholder( + tf.float32, [None, fingerprint_size], name='fingerprint_input') + + logits, dropout_prob = models.create_model( + fingerprint_input, + model_settings, + FLAGS.model_architecture, + is_training=True) + + # Define loss and optimizer + ground_truth_input = tf.placeholder( + tf.float32, [None, label_count], name='groundtruth_input') + + # Optionally we can add runtime checks to spot when NaNs or other symptoms of + # numerical errors start occurring during training. + control_dependencies = [] + if FLAGS.check_nans: + checks = tf.add_check_numerics_ops() + control_dependencies = [checks] + + # Create the back propagation and training evaluation machinery in the graph. + with tf.name_scope('cross_entropy'): + cross_entropy_mean = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + labels=ground_truth_input, logits=logits)) + tf.summary.scalar('cross_entropy', cross_entropy_mean) + with tf.name_scope('train'), tf.control_dependencies(control_dependencies): + learning_rate_input = tf.placeholder( + tf.float32, [], name='learning_rate_input') + train_step = tf.train.GradientDescentOptimizer( + learning_rate_input).minimize(cross_entropy_mean) + predicted_indices = tf.argmax(logits, 1) + expected_indices = tf.argmax(ground_truth_input, 1) + correct_prediction = tf.equal(predicted_indices, expected_indices) + confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices) + evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + tf.summary.scalar('accuracy', evaluation_step) + + global_step = tf.contrib.framework.get_or_create_global_step() + increment_global_step = tf.assign(global_step, global_step + 1) + + saver = tf.train.Saver(tf.global_variables()) + + # Merge all the summaries and write them out to /tmp/retrain_logs (by default) + merged_summaries = tf.summary.merge_all() + train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', + sess.graph) + validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation') + + tf.global_variables_initializer().run() + + start_step = 1 + + if FLAGS.start_checkpoint: + models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint) + start_step = global_step.eval(session=sess) + + tf.logging.info('Training from step: %d ', start_step) + + # Save graph.pbtxt. + tf.train.write_graph(sess.graph_def, FLAGS.train_dir, + FLAGS.model_architecture + '.pbtxt') + + # Save list of words. + with gfile.GFile( + os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'), + 'w') as f: + f.write('\n'.join(audio_processor.words_list)) + + # Training loop. + training_steps_max = np.sum(training_steps_list) + for training_step in xrange(start_step, training_steps_max + 1): + # Figure out what the current learning rate is. + training_steps_sum = 0 + for i in range(len(training_steps_list)): + training_steps_sum += training_steps_list[i] + if training_step <= training_steps_sum: + learning_rate_value = learning_rates_list[i] + break + # Pull the audio samples we'll use for training. 
+ train_fingerprints, train_ground_truth = audio_processor.get_data( + FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency, + FLAGS.background_volume, time_shift_samples, 'training', sess) + # Run the graph with this batch of training data. + train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run( + [ + merged_summaries, evaluation_step, cross_entropy_mean, train_step, + increment_global_step + ], + feed_dict={ + fingerprint_input: train_fingerprints, + ground_truth_input: train_ground_truth, + learning_rate_input: learning_rate_value, + dropout_prob: 0.5 + }) + train_writer.add_summary(train_summary, training_step) + tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' % + (training_step, learning_rate_value, train_accuracy * 100, + cross_entropy_value)) + is_last_step = (training_step == training_steps_max) + if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step: + set_size = audio_processor.set_size('validation') + total_accuracy = 0 + total_conf_matrix = None + for i in xrange(0, set_size, FLAGS.batch_size): + validation_fingerprints, validation_ground_truth = ( + audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0, + 0.0, 0, 'validation', sess)) + # Run a validation step and capture training summaries for TensorBoard + # with the `merged` op. + validation_summary, validation_accuracy, conf_matrix = sess.run( + [merged_summaries, evaluation_step, confusion_matrix], + feed_dict={ + fingerprint_input: validation_fingerprints, + ground_truth_input: validation_ground_truth, + dropout_prob: 1.0 + }) + validation_writer.add_summary(validation_summary, training_step) + batch_size = min(FLAGS.batch_size, set_size - i) + total_accuracy += (validation_accuracy * batch_size) / set_size + if total_conf_matrix is None: + total_conf_matrix = conf_matrix + else: + total_conf_matrix += conf_matrix + tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix)) + tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' % + (training_step, total_accuracy * 100, set_size)) + + # Save the model checkpoint periodically. 
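The validation accuracy above is accumulated as a batch-size-weighted average so that a final partial batch does not count as much as a full one. A small worked example, with made-up per-batch accuracies, shows the weighting:

# Worked example of the batch-weighted accuracy accumulation above; the
# per-batch accuracies here are made up purely for illustration.
set_size = 250
batch_accuracies = [0.90, 0.80, 0.60]
total_accuracy = 0.0
for i, accuracy in zip(range(0, set_size, 100), batch_accuracies):
  batch_size = min(100, set_size - i)      # 100, 100, then 50 examples
  total_accuracy += (accuracy * batch_size) / set_size
print(total_accuracy)  # 0.90 * 0.4 + 0.80 * 0.4 + 0.60 * 0.2 = 0.80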
+ if (training_step % FLAGS.save_step_interval == 0 or + training_step == training_steps_max): + checkpoint_path = os.path.join(FLAGS.train_dir, + FLAGS.model_architecture + '.ckpt') + tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step) + saver.save(sess, checkpoint_path, global_step=training_step) + + set_size = audio_processor.set_size('testing') + tf.logging.info('set_size=%d', set_size) + total_accuracy = 0 + total_conf_matrix = None + for i in xrange(0, set_size, FLAGS.batch_size): + test_fingerprints, test_ground_truth = audio_processor.get_data( + FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess) + test_accuracy, conf_matrix = sess.run( + [evaluation_step, confusion_matrix], + feed_dict={ + fingerprint_input: test_fingerprints, + ground_truth_input: test_ground_truth, + dropout_prob: 1.0 + }) + batch_size = min(FLAGS.batch_size, set_size - i) + total_accuracy += (test_accuracy * batch_size) / set_size + if total_conf_matrix is None: + total_conf_matrix = conf_matrix + else: + total_conf_matrix += conf_matrix + tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix)) + tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100, + set_size)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--data_url', + type=str, + # pylint: disable=line-too-long + default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz', + # pylint: enable=line-too-long + help='Location of speech training data archive on the web.') + parser.add_argument( + '--data_dir', + type=str, + default='/tmp/speech_dataset/', + help="""\ + Where to download the speech training data to. + """) + parser.add_argument( + '--background_volume', + type=float, + default=0.1, + help="""\ + How loud the background noise should be, between 0 and 1. + """) + parser.add_argument( + '--background_frequency', + type=float, + default=0.8, + help="""\ + How many of the training samples have background noise mixed in. + """) + parser.add_argument( + '--silence_percentage', + type=float, + default=10.0, + help="""\ + How much of the training data should be silence. + """) + parser.add_argument( + '--unknown_percentage', + type=float, + default=10.0, + help="""\ + How much of the training data should be unknown words. + """) + parser.add_argument( + '--time_shift_ms', + type=float, + default=100.0, + help="""\ + Range to randomly shift the training audio by in time. 
+ """) + parser.add_argument( + '--testing_percentage', + type=int, + default=10, + help='What percentage of wavs to use as a test set.') + parser.add_argument( + '--validation_percentage', + type=int, + default=10, + help='What percentage of wavs to use as a validation set.') + parser.add_argument( + '--sample_rate', + type=int, + default=16000, + help='Expected sample rate of the wavs',) + parser.add_argument( + '--clip_duration_ms', + type=int, + default=1000, + help='Expected duration in milliseconds of the wavs',) + parser.add_argument( + '--window_size_ms', + type=float, + default=20.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--window_stride_ms', + type=float, + default=10.0, + help='How long each spectrogram timeslice is',) + parser.add_argument( + '--dct_coefficient_count', + type=int, + default=40, + help='How many bins to use for the MFCC fingerprint',) + parser.add_argument( + '--how_many_training_steps', + type=str, + default='15000,3000', + help='How many training loops to run',) + parser.add_argument( + '--eval_step_interval', + type=int, + default=400, + help='How often to evaluate the training results.') + parser.add_argument( + '--learning_rate', + type=str, + default='0.001,0.0001', + help='How large a learning rate to use when training.') + parser.add_argument( + '--batch_size', + type=int, + default=100, + help='How many items to train with at once',) + parser.add_argument( + '--summaries_dir', + type=str, + default='/tmp/retrain_logs', + help='Where to save summary logs for TensorBoard.') + parser.add_argument( + '--wanted_words', + type=str, + default='yes,no,up,down,left,right,on,off,stop,go', + help='Words to use (others will be added to an unknown label)',) + parser.add_argument( + '--train_dir', + type=str, + default='/tmp/speech_commands_train', + help='Directory to write event logs and checkpoint.') + parser.add_argument( + '--save_step_interval', + type=int, + default=100, + help='Save model checkpoint every save_steps.') + parser.add_argument( + '--start_checkpoint', + type=str, + default='', + help='If specified, restore this pretrained model before any training.') + parser.add_argument( + '--model_architecture', + type=str, + default='conv', + help='What model architecture to use') + parser.add_argument( + '--check_nans', + type=bool, + default=False, + help='Whether to check for invalid numbers during processing') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a5a1f808a86..9e352fc52b6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1133,6 +1133,15 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "audio_ops_gen", + require_shape_functions = True, + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/contrib/framework:__pkg__", + ], +) + tf_gen_op_wrapper_private_py( name = "candidate_sampling_ops_gen", visibility = ["//learning/brain/python/ops:__pkg__"],