# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for data input for speech commands.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import numpy as np import tensorflow as tf from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio from tensorflow.examples.speech_commands import input_data from tensorflow.examples.speech_commands import models from tensorflow.python.framework import test_util from tensorflow.python.platform import test class InputDataTest(test.TestCase): def _getWavData(self): with self.cached_session() as sess: sample_data = tf.zeros([32000, 2]) wav_encoder = contrib_audio.encode_wav(sample_data, 16000) wav_data = self.evaluate(wav_encoder) return wav_data def _saveTestWavFile(self, filename, wav_data): with open(filename, "wb") as f: f.write(wav_data) def _saveWavFolders(self, root_dir, labels, how_many): wav_data = self._getWavData() for label in labels: dir_name = os.path.join(root_dir, label) os.mkdir(dir_name) for i in range(how_many): file_path = os.path.join(dir_name, "some_audio_%d.wav" % i) self._saveTestWavFile(file_path, wav_data) def _model_settings(self): return { "desired_samples": 160, "fingerprint_size": 40, "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, "fingerprint_width": 40, "preprocess": "mfcc", } def _runGetDataTest(self, preprocess, window_length_ms): tmp_dir = self.get_temp_dir() wav_dir = os.path.join(tmp_dir, "wavs") os.mkdir(wav_dir) self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) background_dir = os.path.join(wav_dir, "_background_noise_") os.mkdir(background_dir) wav_data = self._getWavData() for i in range(10): file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) self._saveTestWavFile(file_path, wav_data) model_settings = models.prepare_model_settings( 4, 16000, 1000, window_length_ms, 20, 40, preprocess) with self.cached_session() as sess: audio_processor = input_data.AudioProcessor( "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir) result_data, result_labels = audio_processor.get_data( 10, 0, model_settings, 0.3, 0.1, 100, "training", sess) self.assertEqual(10, len(result_data)) self.assertEqual(10, len(result_labels)) def testPrepareWordsList(self): words_list = ["a", "b"] self.assertGreater( len(input_data.prepare_words_list(words_list)), len(words_list)) def testWhichSet(self): self.assertEqual( input_data.which_set("foo.wav", 10, 10), input_data.which_set("foo.wav", 10, 10)) self.assertEqual( input_data.which_set("foo_nohash_0.wav", 10, 10), input_data.which_set("foo_nohash_1.wav", 10, 10)) @test_util.run_deprecated_v1 def testPrepareDataIndex(self): tmp_dir = self.get_temp_dir() self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10, self._model_settings(), tmp_dir) self.assertLess(0, audio_processor.set_size("training")) self.assertTrue("training" in audio_processor.data_index) self.assertTrue("validation" in audio_processor.data_index) self.assertTrue("testing" in audio_processor.data_index) self.assertEquals(input_data.UNKNOWN_WORD_INDEX, audio_processor.word_to_index["c"]) def testPrepareDataIndexEmpty(self): tmp_dir = self.get_temp_dir() self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0) with self.assertRaises(Exception) as e: _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10, self._model_settings(), tmp_dir) self.assertTrue("No .wavs found" in str(e.exception)) def testPrepareDataIndexMissing(self): tmp_dir = self.get_temp_dir() self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) with self.assertRaises(Exception) as e: _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10, 10, self._model_settings(), tmp_dir) self.assertTrue("Expected to find" in str(e.exception)) @test_util.run_deprecated_v1 def testPrepareBackgroundData(self): tmp_dir = self.get_temp_dir() background_dir = os.path.join(tmp_dir, "_background_noise_") os.mkdir(background_dir) wav_data = self._getWavData() for i in range(10): file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) self._saveTestWavFile(file_path, wav_data) self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10, self._model_settings(), tmp_dir) self.assertEqual(10, len(audio_processor.background_data)) def testLoadWavFile(self): tmp_dir = self.get_temp_dir() file_path = os.path.join(tmp_dir, "load_test.wav") wav_data = self._getWavData() self._saveTestWavFile(file_path, wav_data) sample_data = input_data.load_wav_file(file_path) self.assertIsNotNone(sample_data) def testSaveWavFile(self): tmp_dir = self.get_temp_dir() file_path = os.path.join(tmp_dir, "load_test.wav") save_data = np.zeros([16000, 1]) input_data.save_wav_file(file_path, save_data, 16000) loaded_data = input_data.load_wav_file(file_path) self.assertIsNotNone(loaded_data) self.assertEqual(16000, len(loaded_data)) @test_util.run_deprecated_v1 def testPrepareProcessingGraph(self): tmp_dir = self.get_temp_dir() wav_dir = os.path.join(tmp_dir, "wavs") os.mkdir(wav_dir) self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) background_dir = os.path.join(wav_dir, "_background_noise_") os.mkdir(background_dir) wav_data = self._getWavData() for i in range(10): file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) self._saveTestWavFile(file_path, wav_data) model_settings = { "desired_samples": 160, "fingerprint_size": 40, "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, "fingerprint_width": 40, "preprocess": "mfcc", } audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir) self.assertIsNotNone(audio_processor.wav_filename_placeholder_) self.assertIsNotNone(audio_processor.foreground_volume_placeholder_) self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_) self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_) self.assertIsNotNone(audio_processor.background_data_placeholder_) self.assertIsNotNone(audio_processor.background_volume_placeholder_) self.assertIsNotNone(audio_processor.output_) @test_util.run_deprecated_v1 def testGetDataAverage(self): self._runGetDataTest("average", 10) @test_util.run_deprecated_v1 def testGetDataAverageLongWindow(self): self._runGetDataTest("average", 30) @test_util.run_deprecated_v1 def testGetDataMfcc(self): self._runGetDataTest("mfcc", 30) @test_util.run_deprecated_v1 def testGetUnprocessedData(self): tmp_dir = self.get_temp_dir() wav_dir = os.path.join(tmp_dir, "wavs") os.mkdir(wav_dir) self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) model_settings = { "desired_samples": 160, "fingerprint_size": 40, "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, "fingerprint_width": 40, "preprocess": "mfcc", } audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir) result_data, result_labels = audio_processor.get_unprocessed_data( 10, model_settings, "training") self.assertEqual(10, len(result_data)) self.assertEqual(10, len(result_labels)) @test_util.run_deprecated_v1 def testGetFeaturesForWav(self): tmp_dir = self.get_temp_dir() wav_dir = os.path.join(tmp_dir, "wavs") os.mkdir(wav_dir) self._saveWavFolders(wav_dir, ["a", "b", "c"], 1) desired_samples = 1600 model_settings = { "desired_samples": desired_samples, "fingerprint_size": 40, "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, "fingerprint_width": 40, "average_window_width": 6, "preprocess": "average", } with self.cached_session() as sess: audio_processor = input_data.AudioProcessor( "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir) sample_data = np.zeros([desired_samples, 1]) for i in range(desired_samples): phase = i % 4 if phase == 0: sample_data[i, 0] = 0 elif phase == 1: sample_data[i, 0] = -1 elif phase == 2: sample_data[i, 0] = 0 elif phase == 3: sample_data[i, 0] = 1 test_wav_path = os.path.join(tmp_dir, "test_wav.wav") input_data.save_wav_file(test_wav_path, sample_data, 16000) results = audio_processor.get_features_for_wav(test_wav_path, model_settings, sess) spectrogram = results[0] self.assertEqual(1, spectrogram.shape[0]) self.assertEqual(16, spectrogram.shape[1]) self.assertEqual(11, spectrogram.shape[2]) self.assertNear(0, spectrogram[0, 0, 0], 0.1) self.assertNear(200, spectrogram[0, 0, 5], 0.1) def testGetFeaturesRange(self): model_settings = { "preprocess": "average", } features_min, _ = input_data.get_features_range(model_settings) self.assertNear(0.0, features_min, 1e-5) def testGetMfccFeaturesRange(self): model_settings = { "preprocess": "mfcc", } features_min, features_max = input_data.get_features_range(model_settings) self.assertLess(features_min, features_max) if __name__ == "__main__": test.main()